[feature] Embedding weight tying (#169) #172

Merged: blefaudeux merged 3 commits into main from weight_tying on Jan 5, 2022

Conversation

@blefaudeux (Contributor) commented Jan 3, 2022

What does this PR do?

Tentative implementation of #169, fairly minor, with a matching unit test update.
cc @erip
See #169 for a reference and more context.
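
As background, embedding weight tying means the output projection reuses the input embedding matrix, so the vocab-sized weight is stored and trained only once. Below is a minimal, illustrative PyTorch sketch; the `TinyLM` class and `tie_weights` flag are made up for illustration and are not the xformers API touched by this PR.

```python
# Illustrative only: the general weight-tying pattern in plain PyTorch,
# independent of how this PR wires it into the xformers model factory.
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, tie_weights: bool = True):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.proj = nn.Linear(d_model, vocab_size, bias=False)
        if tie_weights:
            # Share a single Parameter between the input embedding and the
            # output projection: one vocab_size x d_model matrix, stored once.
            self.proj.weight = self.embed.weight

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(token_ids))

model = TinyLM(vocab_size=32_000, d_model=512)
assert model.proj.weight is model.embed.weight  # the two layers share one tensor
```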

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot added the CLA Signed label on Jan 3, 2022 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
@blefaudeux blefaudeux marked this pull request as draft January 3, 2022 22:40
@codecov-commenter

Codecov Report

Merging #172 (77fa504) into main (154b819) will increase coverage by 0.01%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #172      +/-   ##
==========================================
+ Coverage   90.56%   90.58%   +0.01%     
==========================================
  Files          56       56              
  Lines        2829     2835       +6     
==========================================
+ Hits         2562     2568       +6     
  Misses        267      267              
| Flag   | Coverage Δ                   |
|--------|------------------------------|
| Python | 90.58% <100.00%> (+0.01%) ⬆️ |

Flags with carried forward coverage won't be shown.

| Impacted Files                    | Coverage Δ                   |
|-----------------------------------|------------------------------|
| xformers/factory/model_factory.py | 97.82% <100.00%> (+0.15%) ⬆️ |

Continue to review full report at Codecov.

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 70019c4...77fa504. Read the comment docs.

@blefaudeux blefaudeux marked this pull request as ready for review January 4, 2022 01:35
@erip (Contributor) commented Jan 4, 2022

Looks good! Huge thanks, @blefaudeux. I can pull the changes into my project that uses xformers and try them out.

@erip (Contributor) commented Jan 4, 2022

Looks like there's a tiny bit of performance improvement (on my silly CPU machine):

➜  python train.py tie
Epoch 1 step: 1 Loss: 9.666109 Took 5.800530 seconds. bsz (toks): 2438
Epoch 1 step: 2 Loss: 8.875856 Took 15.910649 seconds. bsz (toks): 3595
Epoch 1 step: 3 Loss: 7.488206 Took 15.342067 seconds. bsz (toks): 3866
^C
...
➜  python train.py
Epoch 1 step: 1 Loss: 9.688322 Took 5.884865 seconds. bsz (toks): 2438
Epoch 1 step: 2 Loss: 8.820903 Took 16.065957 seconds. bsz (toks): 3595
Epoch 1 step: 3 Loss: 7.255448 Took 15.518760 seconds. bsz (toks): 3866

I think this is generally the right direction. I could also take a look at memory-utilization comparisons between them. That said, printing out the number of trainable parameters shows that this seems to work well:

Tied: There are 25,588,799 trainable parameters.
Untied: There are 33,359,807 trainable parameters.
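
For reference, a trainable-parameter count like the ones above can be reproduced for any `torch.nn.Module`. This is just a sketch: `count_trainable_parameters` and the toy `mlp` are hypothetical and not the script that produced the numbers above.

```python
import torch.nn as nn

def count_trainable_parameters(model: nn.Module) -> int:
    # Sum element counts over parameters that require gradients; tied weights
    # are a single shared Parameter, so they are counted only once.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Example with an arbitrary module (not the model quoted above):
mlp = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))
print(f"There are {count_trainable_parameters(mlp):,} trainable parameters.")
```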

@blefaudeux (Contributor, Author)

> Looks like there's a tiny bit of performance improvement (on my silly CPU machine) […] printing out the number of trainable parameters shows that this seems to work well.

Would you have a small enough task in mind? It could be added to the examples and would be useful for sanity checking and catching perf regressions.

@erip (Contributor) commented Jan 4, 2022

This example is somewhat involved (machine translation), but I could probably make something smaller. If that's of interest, I'm happy to try to contribute something!

@blefaudeux (Contributor, Author)

> This example is somewhat involved (machine translation), but I could probably make something smaller. If that's of interest, I'm happy to try to contribute something!

Just if it's not too much work! There are two examples here if that helps. Also, you'll really need a GPU at some point :D

@dianaml0 (Contributor) left a comment

LGTM!

@blefaudeux blefaudeux merged commit 3422b41 into main Jan 5, 2022
@blefaudeux blefaudeux deleted the weight_tying branch January 18, 2022 04:46
xwhan pushed a commit to xwhan/xformers that referenced this pull request Feb 8, 2022