add flax whisper implementation #20479

Merged: 125 commits into huggingface:main on Feb 20, 2023

Conversation

@andyehrenberg (Contributor) commented Nov 28, 2022

Adds Flax whisper implementations, and adjusts flax generation utils to support it.

@ydshieh @ArthurZucker

See discussion in #19512

@ydshieh (Collaborator) commented Nov 29, 2022

@andyehrenberg

Thank you for the PR. However, a pull request should focus on a single objective/goal, rather than changing multiple things at the same time which are not absolutely coupled.

Please

  • follow the pytorch implementation regarding the past_key_values
  • revert the changes on the flax generation utils
    (You may want to have a backup branch to save these changes for future pull requests.)

The goal of this PR is to add a Flax implementation of Whisper. For other changes, it's better to open issue tickets, and if we all agree with the proposals, a PR could proceed :-)

Thank you!

@andyehrenberg (Contributor, Author) commented Nov 29, 2022

I see a few other instances in this repo where the PyTorch implementation computes past_key_values_length while the Flax implementation uses position_ids (BART, OPT, etc.). To me, keeping consistency among the APIs of the Flax models is something we should strive for. What do you think @ydshieh @patrickvonplaten?

Happy to remove the changes to the generation stuff and open a separate PR for that - will definitely do this to make flax Whisper generation work!

@ydshieh (Collaborator) commented Nov 29, 2022

I wasn't aware of that inconsistency, thank you for pointing it out. This is a good question! But I don't think it's a very serious problem so far - the most important thing is that the different frameworks produce the same outputs when fed the same (supported) inputs, and that the API at the top model level stays consistent.

(The internal computation could differ somewhat, if there is a good reason.)

In any case, this could be discussed in an issue and we can proceed with a PR once decided :-)
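
As a concrete illustration of that equivalence goal, here is a minimal sketch (assuming the eventual FlaxWhisperForConditionalGeneration mirrors the PyTorch call signature; the checkpoint name and dummy inputs are only examples):

import numpy as np
import torch
from transformers import FlaxWhisperForConditionalGeneration, WhisperForConditionalGeneration

pt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
fx_model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)

input_features = np.zeros((1, 80, 3000), dtype=np.float32)  # dummy log-mel features
decoder_input_ids = np.array([[pt_model.config.decoder_start_token_id]])

with torch.no_grad():
    pt_logits = pt_model(
        input_features=torch.from_numpy(input_features),
        decoder_input_ids=torch.from_numpy(decoder_input_ids),
    ).logits.numpy()

fx_logits = np.asarray(
    fx_model(input_features=input_features, decoder_input_ids=decoder_input_ids).logits
)

# The two frameworks should agree up to small numerical differences.
print("max abs diff:", np.abs(pt_logits - fx_logits).max())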

@ydshieh (Collaborator) commented Nov 29, 2022

BTW, there is some issue with triggering CircleCI. The message is:

Could not find a usable config.yml, you may have revoked the CircleCI OAuth app.
Please sign out of CircleCI and log back in with your VCS before triggering a new pipeline.

Do you use an IDE to push the commits? Could you try to push the commits with a command-line tool or a git GUI tool instead?

@HuggingFaceDocBuilderDev commented Nov 29, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor)

Also cc @sanchit-gandhi

@ArthurZucker (Collaborator) commented Dec 2, 2022

Hey! Thanks for opening the follow-up PR 🤗

I don't think I agree with @ydshieh here: adding the Flax generation utils changes along with Whisper makes total sense, as it was done for PyTorch and TF, and it is required to add the generation tests which are currently missing!
Regarding past_key_values, we don't really strive to match transformers with other APIs; rather, I think we prefer consistency within our own library, and code clarity.
However, you can still open an issue and we can discuss whether we should refactor the design of past_key_values for our Flax models!

Will have a look at the PR 😉

@ArthurZucker self-assigned this on Dec 2, 2022
@ydshieh (Collaborator) commented Dec 2, 2022

You are right! I wasn't aware that those generation features were introduced when you added Whisper, @ArthurZucker. Sorry about that, @andyehrenberg!

@sanchit-gandhi (Contributor)

Super excited by this PR! 🚀 Feel free to tag me with questions / review requests as well @andyehrenberg 🤗

@ArthurZucker (Collaborator) left a comment

Nice work there!
I don't really think we are going to push for the scan methods, but it is debatable. @sgugger correct me if I am wrong.

Resolved review comments on src/transformers/models/whisper/modeling_flax_whisper.py
Comment on lines 707 to 712
if attention_mask is not None:
    if position_ids is None:
        position_ids = attention_mask.cumsum(-1) - 1
if position_ids is None:
    batch_size, sequence_length = input_ids.shape
    position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
Collaborator

Would be great if we could follow the simple logic that we have in the PyTorch version, where we use the input_ids with self.embed_positions(input_ids, past_key_values_length=past_key_values_length).

@andyehrenberg (Contributor, Author) commented Dec 2, 2022

I think we should stick with computing position_ids, both to keep a similar API to the other Flax models and because this better handles the scenario where we run generation on a batch with different decoder prompt lengths. The PyTorch version ends up just using past_key_values_length to compute something akin to position_ids, but we can use the attention_mask to figure them out directly. I'd actually argue we should change the PyTorch Whisper implementation to use position_ids, because as it currently stands it will fail to decode batches of varying decoder prompt lengths - it should take more inspiration from the decoder-only models that compute position_ids, as opposed to the encoder-decoder models that don't assume decoder prefixes.
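
For example (illustrative values), deriving positions from the attention mask gives each example its own offset, which a single past_key_values_length cannot express when prompts are left-padded to different lengths:

import jax.numpy as jnp

# Two decoder prompts left-padded to length 5 (0 = padding, 1 = real token).
attention_mask = jnp.array(
    [
        [0, 0, 1, 1, 1],  # prompt of length 3
        [1, 1, 1, 1, 1],  # prompt of length 5
    ]
)

position_ids = attention_mask.cumsum(-1) - 1
# [[-1, -1,  0,  1,  2],
#  [ 0,  1,  2,  3,  4]]
# Real tokens start at position 0 in both rows regardless of padding, and the
# padded slots (-1) are masked out by the attention mask anyway.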

Contributor

I agree with @andyehrenberg that we should use the Flax implementation here. However, it would be better still in terms of Flax compatibility if this logic went under the pretrained model's decode/__call__ methods, rather than under FlaxWhisperDecoder (as we do in Flax MBart, for example).

Resolved review comments on src/transformers/models/whisper/modeling_flax_whisper.py
@ArthurZucker (Collaborator)

Also sorry! We just modified Whisper quite a bit 😅

@andyehrenberg (Contributor, Author) commented Jan 26, 2023

> Also sorry! We just modified Whisper quite a bit 😅

@ArthurZucker - It doesn't actually look too bad to catch up with those changes! I can do that soon-ish. I already have a JAX timestamp processor that's compilable.

@sanchit-gandhi (Contributor)

Oh no - sorry you have to iterate again here @andyehrenberg! Feel free to ping me with any questions / discussions - more than happy to help with the final sprint of the integration! Otherwise super excited to review a final time before merge! 🚀

@andyehrenberg (Contributor, Author)

@sanchit-gandhi - I think this is ready for another look - the recent commits (I think) get us to feature parity with the torch version.

@andyehrenberg (Contributor, Author)

@sanchit-gandhi Bump

@ArthurZucker (Collaborator) left a comment

Wow! Very clean, thanks a lot for the long work! I just left one comment on testing the timestamp generation, but it should be good to merge otherwise! cc @sanchit-gandhi

Resolved review comments on src/transformers/generation/flax_utils.py
# fmt: on

transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
Collaborator

Can you add a test_tiny_timestamp_generation test where you check whether jit compilation produces the correct timestamps?

Collaborator

This is just to make sure that the logits processor correctly predicts them. I speak from TF experience: my code worked, but once compiled it started failing 😓

@andyehrenberg (Contributor, Author)

Just added - some local sanity checks were working for me under jit compilation at least!
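
For reference, a rough sketch of what such a jit-compiled timestamp check can look like (the checkpoint name, shapes, and generate kwargs here are assumptions, not the exact test code):

import functools

import jax
import numpy as np
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)

# In a real test the features would come from processor(audio, return_tensors="np").input_features.
input_features = np.zeros((1, 80, 3000), dtype=np.float32)

# jit the whole generate call so the timestamp logits processor runs under compilation.
generate_fn = jax.jit(
    functools.partial(model.generate, max_length=448, return_timestamps=True)
)
generated_ids = generate_fn(input_features).sequences

transcript = processor.batch_decode(generated_ids, skip_special_tokens=False)
# The actual test then compares generated_ids / transcript against a reference
# transcript that includes timestamp tokens, e.g. via self.assertListEqual.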

@andyehrenberg (Contributor, Author)

@sanchit-gandhi @ArthurZucker - Addressed Arthur's comments and cleaned up the timestamp logits processor a bit. Hopefully we're close to getting this merged!

@sanchit-gandhi (Contributor) left a comment

Very nice @andyehrenberg! Thanks for iterating here - reviewed the new changes and the PR is looking super clean. Last request from me is if we can avoid defining the if_true() functions if possible and just add the code explicitly! Good for merge otherwise :)

Resolved review comments on src/transformers/generation/flax_logits_process.py
@andyehrenberg (Contributor, Author)

> Very nice @andyehrenberg! Thanks for iterating here - reviewed the new changes and the PR is looking super clean. Last request from me is if we can avoid defining the if_true() functions if possible and just add the code explicitly! Good for merge otherwise :)

For sure, made those changes :)
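
For illustration, a simplified before/after of this kind of refactor, replacing a one-off if_true() branch helper with explicit jnp.where masking (the function and variable names are hypothetical, not the actual processor code):

import jax
import jax.numpy as jnp

NEG_INF = float("-inf")

# scores: 1-D array of logits for one example; timestamp_begin: static int from the config.

# Before: a nested if_true() helper defined only to express one lax.cond branch.
def mask_text_tokens_cond(scores, last_was_timestamp, timestamp_begin):
    def if_true(s):
        # force the next token to be a timestamp by blocking the text vocabulary
        return s.at[:timestamp_begin].set(NEG_INF)

    return jax.lax.cond(last_was_timestamp, if_true, lambda s: s, scores)

# After: the same conditional masking written out explicitly with jnp.where.
def mask_text_tokens_where(scores, last_was_timestamp, timestamp_begin):
    is_text_token = jnp.arange(scores.shape[-1]) < timestamp_begin
    return jnp.where(is_text_token & last_was_timestamp, NEG_INF, scores)

Both compute the same values under jit; the jnp.where form just avoids the extra helper indirection.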

@sgugger (Collaborator) left a comment

Thanks again for your contribution!

@sgugger merged commit 2840272 into huggingface:main on Feb 20, 2023
@andyehrenberg deleted the flax_whisper branch on February 28, 2023 at 20:31
@ArthurZucker added a commit to ArthurZucker/transformers that referenced this pull request on Mar 2, 2023:
* add flax whisper implementation

* rever change to setup

* remove unused imports

* revert generation changes

* flax whisper docs

* docs

* import order

* import sorting

* isort

* add dummy objects

* doc formatting

* formatting

* remove trailing whitespaces

* fix flax whisper docs

* add generation logic to unlock flax whisper

* remove scans

* give credits to Flax Bart implementation

* remove unused imports

* add license

* remove assert

* more credits to Bart

* fix style

* formatting

* support left padding

* add flax whisper generation test

* remove copied from comments whenever not a full copy

* fix docstrings for logits processors

* revert change to FlaxForceTokensLogitsProcessor

* revert doc changes

* improve generation docs

* reorganize

* formatting

* cleanup docs

* add tests

* handle empty list case

* fix forced decoder ids in flax tests

* add flax whisper to inits

* upate dummy objects

* docs for FlaxAutoModelForSpeechSeq2Seq

* fix decoder_position_ids computation in pretrained model decode/__call__ fns

* add Copied from statements as necessary

* compute position_ids only in __call__ and decode methods of pretrained model subclasses

* improve readabilityof compute positional embeddings

* check dimensionality of input_features instead of hidden_states

* copied from statement for init_cache

* formatting

* fix copies

* fix copies

* pass attention mask to encoder layers

* fix decoder module outputs

* set dtype

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* smaller flax model for whisper test

* Update src/transformers/generation/flax_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_flax_whisper.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/models/whisper/test_modeling_flax_whisper.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* cleanup

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_flax_whisper.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* bias cleanup

* doc fix

* align style for force tokens processor

* readability

* fix input shape in tests

* revert FlaxGenerationMixin docstring

* formatting

* fix tests

* fix imports

* consistent encoder hidden states

* consistent hidden states

* input shapes

* typo

* partial class trick

* partial class for input shape

* base_class with correct input shape

* partial base classes

* match by name

* set main_input_name

* compare on names

* formatting

* remove unused import

* safer position ids computation

* safer position id computation

* Update src/transformers/models/whisper/modeling_flax_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_flax_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove identical inherited tests

* fix prompt ids in tests

* use generation config

* use jnp array

* better var names

* more explicit bias use

* import transformers

* formatting

* test formatting

* remove unused imports

* remove unused imports

* formatting

* isort

* docs

* fix ln orders for encoder hidden states

* whisper unique generation stuff

* flake

* use finfo for attention bias

* docs

* Update src/transformers/generation/flax_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* docs

* add timestamp flax test

* jit for timestamps

* formatting

* clean up timestamps processor

* formatting

* remove if_true

* cleanup

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@Tungbillee

Are there any instructions for opening the Google Cloud TPU port, admin?

Successfully merging this pull request may close these issues: Feature Request: Flax Whisper