Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ModelingOutput] add tinybert/Electra/XLNet/ALBERT/ERNIE-M more output & loss #3148

Merged
merged 20 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
642 changes: 419 additions & 223 deletions paddlenlp/transformers/albert/modeling.py

Large diffs are not rendered by default.

373 changes: 317 additions & 56 deletions paddlenlp/transformers/electra/modeling.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion paddlenlp/transformers/electra/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
"electra-small": 512,
"electra-base": 512,
"electra-large": 512,
"chinese-electra-small": 512,
"chinese-electra-base": 512,
"chinese-electra-small": 512,
"ernie-health-chinese": 512
}

Expand Down
223 changes: 198 additions & 25 deletions paddlenlp/transformers/ernie_gram/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

from ..ernie.modeling import ErniePooler
from .. import PretrainedModel, register_base_model
from ..model_outputs import (
BaseModelOutputWithPooling,
SequenceClassifierOutput,
TokenClassifierOutput,
QuestionAnsweringModelOutput,
)

__all__ = [
'ErnieGramModel',
Expand Down Expand Up @@ -239,7 +245,10 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
Args:
input_ids (Tensor):
Expand Down Expand Up @@ -272,6 +281,15 @@ def forward(self,
We use whole-word-mask in ERNIE, so the whole word will have the same value. For example, "使用" as a word,
"使" and "用" will have the same value.
Defaults to `None`, which means nothing needed to be prevented attention to.
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output
will be a tuple of tensors. Defaults to `False`.

Returns:
tuple: Returns tuple (``sequence_output``, ``pooled_output``).
Expand Down Expand Up @@ -317,10 +335,28 @@ def forward(self,
embedding_output = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
encoder_outputs = self.encoder(embedding_output, attention_mask)
sequence_output = encoder_outputs
encoder_outputs = self.encoder(
embedding_output,
attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

if paddle.is_tensor(encoder_outputs):
wj-Mcat marked this conversation as resolved.
Show resolved Hide resolved
encoder_outputs = (encoder_outputs, )

sequence_output = encoder_outputs[0]

pooled_output = self.pooler(sequence_output)
return sequence_output, pooled_output

if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions)


class ErnieGramForTokenClassification(ErnieGramPretrainedModel):
Expand Down Expand Up @@ -359,7 +395,11 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
labels=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
Args:
input_ids (Tensor):
Expand All @@ -370,6 +410,17 @@ def forward(self,
See :class:`ErnieGramModel`.
attention_mask (Tensor, optional):
See :class:`ErnieGramModel`.
labels (Tensor of shape `(batch_size, sequence_length)`, optional):
Labels for computing the token classification loss. Indices should be in `[0, ..., num_classes - 1]`.
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If
`False`, the output will be a tuple of tensors. Defaults to `False`.

Returns:
Tensor: Returns tensor `logits`, a tensor of the input token classification logits.
Expand All @@ -388,14 +439,35 @@ def forward(self,
inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
logits = model(**inputs)
"""
sequence_output, _ = self.ernie_gram(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask)
outputs = self.ernie_gram(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

sequence_output = outputs[0]

sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
return logits

loss = None
if labels is not None:
loss_fct = paddle.nn.CrossEntropyLoss()
loss = loss_fct(logits.reshape((-1, self.num_classes)),
labels.reshape((-1, )))
if not return_dict:
output = (logits, ) + outputs[2:]
return ((loss, ) + output) if loss is not None else (
output[0] if len(output) == 1 else output)

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class ErnieGramForQuestionAnswering(ErnieGramPretrainedModel):
Expand All @@ -419,7 +491,12 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
start_positions=None,
end_positions=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
Args:
input_ids (Tensor):
Expand All @@ -430,6 +507,23 @@ def forward(self,
See :class:`ErnieGramModel`.
attention_mask (Tensor, optional):
See :class:`ErnieGramModel`.
start_positions (Tensor of shape `(batch_size,)`, optional):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (Tensor of shape `(batch_size,)`, optional):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` object. If
`False`, the output will be a tuple of tensors. Defaults to `False`.


Returns:
Expand Down Expand Up @@ -459,16 +553,47 @@ def forward(self,
logits = model(**inputs)
"""

sequence_output, _ = self.ernie_gram(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask)
outputs = self.ernie_gram(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

logits = self.classifier(sequence_output)
logits = self.classifier(outputs[0])
logits = paddle.transpose(logits, perm=[2, 0, 1])
start_logits, end_logits = paddle.unstack(x=logits, axis=0)

return start_logits, end_logits
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if start_positions.ndim > 1:
start_positions = start_positions.squeeze(-1)
if start_positions.ndim > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = paddle.shape(start_logits)[1]
start_positions = start_positions.clip(0, ignored_index)
end_positions = end_positions.clip(0, ignored_index)

loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2

if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss, ) +
output) if total_loss is not None else output

return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class ErnieGramForSequenceClassification(ErnieGramPretrainedModel):
Expand Down Expand Up @@ -501,7 +626,11 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
labels=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
Args:
input_ids (Tensor):
Expand All @@ -511,7 +640,22 @@ def forward(self,
position_ids (Tensor, optional):
See :class:`ErnieGramModel`.
attention_mask (Tensor, optional):
See :class:`ErnieGramModel`.
See :class:`BertModel`.
labels (Tensor of shape `(batch_size,)`, optional):
Labels for computing the sequence classification/regression loss.
Indices should be in `[0, ..., num_classes - 1]`. If `num_classes == 1`
a regression loss is computed (Mean-Square loss), If `num_classes > 1`
a classification loss is computed (Cross-Entropy).
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` object. If
`False`, the output will be a tuple of tensors. Defaults to `False`.


Returns:
Tensor: Returns tensor `logits`, a tensor of the input text classification logits.
Expand All @@ -531,11 +675,40 @@ def forward(self,
logits = model(**inputs)

"""
_, pooled_output = self.ernie_gram(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask)
outputs = self.ernie_gram(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

pooled_output = self.dropout(outputs[1])
logits = self.classifier(pooled_output)

pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits

loss = None
if labels is not None:
if self.num_classes == 1:
loss_fct = paddle.nn.MSELoss()
loss = loss_fct(logits, labels)
elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32:
loss_fct = paddle.nn.CrossEntropyLoss()
loss = loss_fct(logits.reshape((-1, self.num_classes)),
labels.reshape((-1, )))
else:
loss_fct = paddle.nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels)

if not return_dict:
output = (logits, ) + outputs[2:]
return ((loss, ) + output) if loss is not None else (
output[0] if len(output) == 1 else output)

return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
Loading