From ceebff9a42668bdc384ef438a483eb52a8036267 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Wed, 31 May 2023 15:37:57 +0100 Subject: [PATCH] [Flax Whisper] Update decode docstring --- src/transformers/models/whisper/modeling_flax_whisper.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 5fbc158a44a02a..c9f3bbba9c3c96 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -1017,16 +1017,17 @@ def decode( ```python >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration >>> from datasets import load_dataset + >>> import jax.numpy as jnp >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") - >>> input_features = inputs.input_features + >>> input_features = processor(ds[0]["audio"]["array"], return_tensors="np").input_features + >>> encoder_outputs = model.encode(input_features=input_features) >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + >>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id >>> outputs = model.decode(decoder_input_ids, encoder_outputs) >>> last_decoder_hidden_states = outputs.last_hidden_state