BLOOM Flax #18022
Conversation
…to add_bloom_flax
…bloom_flax # Conflicts: # src/transformers/models/bloom/modeling_flax_bloom.py
Almost there! Just some small refactoring suggestions to clean the code up a bit
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
…nsformers into add_bloom_flax
Thanks for addressing the previous comments @younesbelkada. Two things from this round of review:
- Could we tidy up the `build_alibi_tensor_flax` function to avoid a triple-nested function? Much of the logic can be copied over from PyTorch BLOOM!
- The big question for me is whether we keep `scan` or not - I'm in favour of removing it for Flax BLOOM (see comments below)
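A flat (non-nested) alibi builder could follow the PyTorch BLOOM logic fairly directly. The sketch below is illustrative only, not the code under review; the function name, signature, and shapes are assumptions based on the PyTorch implementation's slope scheme:

```python
import math
import jax.numpy as jnp

def build_alibi_tensor(attention_mask, num_heads, dtype=jnp.float32):
    # Hypothetical flat sketch mirroring PyTorch BLOOM: compute per-head
    # slopes from the closest power of two, then scale cumulative positions.
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    powers = jnp.arange(1, 1 + closest_power_of_2)
    slopes = base ** powers
    if closest_power_of_2 != num_heads:
        # extra heads get interpolated slopes from the next power of two
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = num_heads - closest_power_of_2
        extra_powers = jnp.arange(1, 1 + 2 * num_remaining, 2)
        slopes = jnp.concatenate([slopes, extra_base ** extra_powers])
    # position index of each attended token, zeroed where the mask is zero
    arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[None, :, None] * arange_tensor
    return alibi.astype(dtype)  # shape (batch_size, num_heads, seq_length)
```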
    all_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None

    if self.use_scan:
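For context, the snippet above follows the standard Transformers pattern of accumulating per-layer outputs in tuples only when the caller requests them. A minimal, self-contained sketch of that pattern (a hypothetical helper, not the PR's code):

```python
def run_layers(layers, hidden_states, output_attentions=False, output_hidden_states=False):
    # Tuples are only allocated when the caller asks for the extra outputs;
    # otherwise None signals "not collected" in the returned structure.
    all_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None
    for layer in layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        hidden_states, attn = layer(hidden_states)
        if output_attentions:
            all_attentions += (attn,)
    if output_hidden_states:
        # include the final hidden state as well
        all_hidden_states += (hidden_states,)
    return hidden_states, all_hidden_states, all_attentions
```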
We've currently left `scan` in the modelling code. Part of me thinks we should remove it for Transformers for the following reasons:
- `scan` adds a lot of boilerplate code that isn't very easy to understand
- It is only beneficial for compile times when the model size is large, mostly when training and less so for inference

In the latter case, users will also likely shard the model, meaning they can employ the standalone code in bloom-jax-inference, where we can retain `scan` functionality.

We also found generation time to be slower when using `scan` vs not using it (despite a faster compile time). The generation time will amortise the compile time in any use case of Flax BLOOM.

Given that the philosophy of Transformers is functional, easy-to-understand code that is not necessarily fully optimised, I'm in favour of stripping `scan` from Flax BLOOM and leaving it to bloom-jax-inference to serve users that want to deploy larger variants of the model.
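To illustrate the trade-off being discussed (a toy sketch, not BLOOM's code): scanning over layers requires stacking per-layer parameters on a leading axis and rewriting the layer as a `(carry, x) -> (carry, y)` body, whereas the unrolled loop is plain Python but gets traced once per layer at compile time.

```python
import jax
import jax.numpy as jnp

num_layers, hidden = 4, 8
# per-layer weights stacked on a leading axis, as scan requires
stacked_w = jnp.stack([jnp.eye(hidden) * (i + 1) for i in range(num_layers)])

def layer(hidden_states, w):
    return hidden_states @ w

# Unrolled Python loop: simple to read, but traced num_layers times.
def forward_unrolled(x):
    for i in range(num_layers):
        x = layer(x, stacked_w[i])
    return x

# lax.scan: the layer body is traced once (faster compile for deep models),
# at the cost of the carry/stacked-parameter plumbing.
def forward_scanned(x):
    def body(carry, w):
        return layer(carry, w), None
    x, _ = jax.lax.scan(body, x, stacked_w)
    return x
```

Both functions compute the same result; the difference is purely in compile-time behaviour and code complexity.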
- remove unused code
- refactor a bit
- revert import `torch`
- change build alibi
Should have addressed your new suggestions @sanchit-gandhi! Here I mainly focused on refactoring the `build_alibi` function to match the implementation style of PyTorch -> this way it seems more readable!
I will leave you, @patil-suraj and @patrickvonplaten to decide regarding the scan feature, and I'm happy to remove it once we agree on that!
Can also confirm the slow tests/conversion tests pass ;)
d789a85 to dcdd563
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Should we merge this one? cc @patrickvonplaten @patil-suraj @sanchit-gandhi
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Both #17761 and this PR show a lot of work. Why isn't it merged?
Adding this to my TODOs
What does this PR do?
An attempt at adding a Flax implementation of BLOOM - original PR from @haileyschoelkopf: #17761
TODOs: