
Adding inputs_embeds argument and switch to paddle.nn.TransformerEncoder for Electra models #3401

Merged: 19 commits merged into PaddlePaddle:develop on Nov 3, 2022

Conversation

@sijunhe (Collaborator) commented on Oct 1, 2022

PR types

Function optimization

PR changes

APIs

Description

Addressing part of #3382

  • Adding an inputs_embeds argument to the Electra models to provide more control over how input_ids indices
    are converted into the embedding space. This is particularly useful for use cases such as P-Tuning (see the sketch below the list).
  • Remove TransformerEncoderPro and switch to paddle.nn.TransformerEncoder
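
For illustration, a caller could compute the token embeddings itself and feed them to the model in place of input_ids, e.g. to prepend learned prompt vectors for P-Tuning. This is a minimal sketch, not code from the PR: the checkpoint name and the model.embeddings.word_embeddings path are assumptions about the PaddleNLP Electra implementation.

    import paddle
    from paddlenlp.transformers import ElectraModel, ElectraTokenizer

    # "electra-small" is used here only as an illustrative checkpoint name.
    tokenizer = ElectraTokenizer.from_pretrained("electra-small")
    model = ElectraModel.from_pretrained("electra-small")

    encoded = tokenizer("PaddleNLP is great!")
    input_ids = paddle.to_tensor([encoded["input_ids"]])

    # Look up the token embeddings manually so they can be edited (e.g. by
    # prepending learned prompt vectors) before they enter the encoder.
    inputs_embeds = model.embeddings.word_embeddings(input_ids)

    # Pass the embeddings directly; input_ids is then left as None.
    outputs = model(input_ids=None, inputs_embeds=inputs_embeds)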

@CLAassistant commented on Oct 1, 2022

CLA assistant check: all committers have signed the CLA.

@guoshengCS (Contributor) previously approved these changes on Oct 11, 2022 and left a comment:

LGTM

@wj-Mcat (Contributor) left a comment:

Your PR looks great. Let's discuss two small suggestions; I'm waiting for your comments.

Comment on lines 208 to 211
if input_ids is not None:
    input_embeddings = self.word_embeddings(input_ids)
else:
    input_embeddings = inputs_embeds
@wj-Mcat (Contributor) commented:
I think this code block can be improved to:

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

and in the forward method, rename input_embeddings to inputs_embeds. That way the code reads more concisely. What do you think?

@guoshengCS (Contributor) commented:

Should the input_embeddings = self.word_embeddings(input_ids) line in the following original code be removed?

        if token_type_ids is None:
            token_type_ids = paddle.zeros_like(input_ids, dtype="int64")
        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = input_embeddings + position_embeddings + token_type_embeddings

@sijunhe (Collaborator, Author) replied on Oct 11, 2022:

Thanks for the review, folks.
@wj-Mcat While I agree that renaming input_embeddings to inputs_embeds makes the code more concise, it also makes it less explicit/readable, so I prefer to keep it the way it is now.
@guoshengCS Good call; I removed the redundant line of code.
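
For reference, after both comments are addressed the embedding branch presumably ends up along these lines. This is a reconstruction from the snippets quoted above, not the exact merged diff; in particular the token_type_ids handling for the input_ids=None case is an assumption.

    if input_ids is not None:
        input_embeddings = self.word_embeddings(input_ids)
    else:
        input_embeddings = inputs_embeds

    if token_type_ids is None:
        # Derive the shape from the embeddings so this also works when
        # input_ids is None (assumption, not taken from the PR).
        token_type_ids = paddle.zeros(input_embeddings.shape[:-1], dtype="int64")

    position_embeddings = self.position_embeddings(position_ids)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)
    embeddings = input_embeddings + position_embeddings + token_type_embeddings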

Comment on lines 81 to 86
inputs_embeds = None
if self.use_inputs_embeds:
    inputs_embeds = floats_tensor(
        [self.batch_size, self.seq_length, self.embedding_size])
    # In order to use inputs_embeds, input_ids needs to be set to None
    input_ids = None
A reviewer (Contributor) commented:

If use_inputs_embeds is set, the prepare_config_and_inputs method should not prepare the input_ids tensor in the first place.
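
Concretely, the suggested test setup might look something like the sketch below, which prepares only one of the two tensors. The ids_tensor helper is assumed to exist alongside floats_tensor in the test utilities; the actual test code may differ.

    input_ids = None
    inputs_embeds = None
    if self.use_inputs_embeds:
        inputs_embeds = floats_tensor(
            [self.batch_size, self.seq_length, self.embedding_size])
    else:
        # Only prepare input_ids when embeddings are not supplied directly.
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)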

@sijunhe (Collaborator, Author) replied:

Addressed

@sijunhe (Collaborator, Author) commented on Oct 11, 2022:

Addressed both comments. I think this PR should be ready for merging.

guoshengCS previously approved these changes on Oct 12, 2022.
@sijunhe requested a review from wj-Mcat on Oct 13, 2022.
@wj-Mcat (Contributor) commented on Oct 17, 2022:

There is another to-do under the TransformerEncoderPro class; see:

if cache is None and getattr(self, "_use_cache", False):
    cache = [tuple(self.layers[0].gen_cache(src))] * len(self.layers)
# To be compatible with `TransformerEncoder.forward`, `_use_cache` defaults
# to True when cache is not None.
new_caches = [] if cache is not None and getattr(self, "_use_cache", True) else None

You set the _use_cache attribute in ElectraModel but did not implement the handler for it, so this block should be changed. What do you think? @sijunhe

@sijunhe (Collaborator, Author) commented on Oct 17, 2022:

(Quoting wj-Mcat's comment above about the _use_cache to-do in TransformerEncoderPro.)

Good catch!
As this to-do pertains to the use of cache and past_key_values, it is outside the scope of this PR. To incorporate the change you asked for, I would need to change not only the handler but the unit tests as well. I think we should merge this PR first, and I'll create a new PR for the implementation.

Regarding the to-do: I looked at the _transformer_encoder_fwd you linked and the TransformerEncoderPro in ELECTRA, and they seem to be identical. I should be able to directly use the existing patched paddle.nn.Transformer instead of creating one for Electra, right?

@wj-Mcat (Contributor) commented on Oct 17, 2022:

There are a few things I want to tell you:

  1. The paddle.nn.Transformer* implementations don't contain the full set of features we want, which is why TransformerEncoderPro was born.
  2. _transformer_encoder_fwd and TransformerEncoderPro are actually identical and should be refactored into paddlenlp/layers/transformer.py as the TransformerEncoderPro class.

There are also some modules that use paddle.nn.TransformerEncoder, paddle.nn.TransformerEncoderLayer, paddle.nn.TransformerDecoder, and paddle.nn.TransformerDecoderLayer; you could change the related modules to make things more unified.

I would prefer that you do it in this PR. What do you think? @sijunhe @guoshengCS

@wj-Mcat (Contributor) commented on Oct 17, 2022:

To get this PR merged, you can make the changes in the TransformerEncoderPro class under the electra.modeling module. The refactoring work can be done over the next few weeks. @sijunhe

@sijunhe (Collaborator, Author) commented on Oct 17, 2022:

(Quoting wj-Mcat's comment above about making the changes in TransformerEncoderPro.)

I noticed that before #3411, TransformerEncoderPro was basically paddle.nn.TransformerEncoder without the cache input and output. With the cache functionality added in #3411 and #3401, TransformerEncoderPro is identical to paddle.nn.TransformerEncoder. Hence I think we can just use paddle.nn.TransformerEncoder in Electra.
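
In that case the Electra encoder can be built directly from the stock Paddle layers, roughly as follows. This is a sketch with illustrative electra-small-sized hyperparameters, not the exact PR diff.

    import paddle.nn as nn

    # Illustrative electra-small-style sizes.
    hidden_size = 256
    num_attention_heads = 4
    intermediate_size = 1024
    num_hidden_layers = 12

    encoder_layer = nn.TransformerEncoderLayer(
        d_model=hidden_size,
        nhead=num_attention_heads,
        dim_feedforward=intermediate_size,
        dropout=0.1,
        activation="gelu",
        attn_dropout=0.1,
        act_dropout=0,
    )
    # The stock paddle.nn.TransformerEncoder replaces the custom TransformerEncoderPro.
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_hidden_layers)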

@sijunhe changed the title from "Adding inputs_embeds argument for Electra models" to "Adding inputs_embeds argument and switch to paddle.nn.TransformerEncoder for Electra models" on Oct 17, 2022.
@wj-Mcat (Contributor) left a comment:

LGTM

@guoshengCS merged commit c38902e into PaddlePaddle:develop on Nov 3, 2022.
@sijunhe deleted the electra_inputs_embeds branch on Nov 3, 2022.