Skip to content

Commit

Permalink
[Flax Whisper] Update decode docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi committed Jun 1, 2023
1 parent fabe17a commit ceebff9
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/transformers/models/whisper/modeling_flax_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,16 +1017,17 @@ def decode(
```python
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> import jax.numpy as jnp
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
>>> input_features = inputs.input_features
>>> input_features = processor(ds[0]["audio"]["array"], return_tensors="np").input_features
>>> encoder_outputs = model.encode(input_features=input_features)
>>> decoder_start_token_id = model.config.decoder_start_token_id
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> last_decoder_hidden_states = outputs.last_hidden_state
Expand Down

0 comments on commit ceebff9

Please sign in to comment.