add flax whisper implementation #20479

Merged Feb 20, 2023: 125 commits from andyehrenberg:flax_whisper into huggingface:main (changes shown from 116 commits)

Commits
7d3b6ef
add flax whisper implementation
andyehrenberg Nov 28, 2022
a9bed4c
revert change to setup
andyehrenberg Nov 28, 2022
0312993
remove unused imports
andyehrenberg Nov 28, 2022
c71fe4f
revert generation changes
andyehrenberg Nov 29, 2022
828d800
flax whisper docs
andyehrenberg Nov 29, 2022
baafb1c
docs
andyehrenberg Dec 1, 2022
7dba8b5
Merge branch 'huggingface:main' into flax_whisper
andyehrenberg Dec 1, 2022
2da5a58
import order
andyehrenberg Dec 1, 2022
5ee9c1f
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Dec 1, 2022
00f695f
import sorting
andyehrenberg Dec 1, 2022
0ecc03b
isort
andyehrenberg Dec 1, 2022
f66a005
add dummy objects
andyehrenberg Dec 1, 2022
175f344
doc formatting
andyehrenberg Dec 1, 2022
3329e6c
formatting
andyehrenberg Dec 1, 2022
c05089b
remove trailing whitespaces
andyehrenberg Dec 1, 2022
7551181
fix flax whisper docs
andyehrenberg Dec 1, 2022
153f2cb
Merge branch 'huggingface:main' into flax_whisper
andyehrenberg Dec 1, 2022
e255a97
add generation logic to unlock flax whisper
andyehrenberg Dec 2, 2022
f8009d7
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Dec 2, 2022
d003074
remove scans
andyehrenberg Dec 2, 2022
ba8a358
give credits to Flax Bart implementation
andyehrenberg Dec 2, 2022
9f4578d
remove unused imports
andyehrenberg Dec 2, 2022
be33fbd
add license
andyehrenberg Dec 2, 2022
8b1338b
remove assert
andyehrenberg Dec 2, 2022
c567f79
more credits to Bart
andyehrenberg Dec 2, 2022
fbe4e25
fix style
andyehrenberg Dec 2, 2022
cde5afd
formatting
andyehrenberg Dec 2, 2022
6aeb8c8
support left padding
andyehrenberg Dec 2, 2022
ec9ca19
add flax whisper generation test
andyehrenberg Dec 5, 2022
8bce923
Merge branch 'huggingface:main' into flax_whisper
andyehrenberg Dec 6, 2022
3f902f6
remove copied from comments whenever not a full copy
andyehrenberg Dec 7, 2022
3fd0a7c
fix docstrings for logits processors
andyehrenberg Dec 7, 2022
abc14a1
revert change to FlaxForceTokensLogitsProcessor
andyehrenberg Dec 7, 2022
d784a23
revert doc changes
andyehrenberg Dec 7, 2022
3dd8282
improve generation docs
andyehrenberg Dec 7, 2022
77fce32
reorganize
andyehrenberg Dec 7, 2022
fefefde
formatting
andyehrenberg Dec 7, 2022
04ad651
cleanup docs
andyehrenberg Dec 7, 2022
14e19c0
add tests
andyehrenberg Dec 7, 2022
cf67b38
handle empty list case
andyehrenberg Dec 7, 2022
3de7509
fix forced decoder ids in flax tests
andyehrenberg Dec 8, 2022
1077588
Merge branch 'huggingface:main' into flax_whisper
andyehrenberg Dec 9, 2022
5e2256a
add flax whisper to inits
andyehrenberg Dec 12, 2022
ada32b8
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Dec 12, 2022
669db4e
update dummy objects
andyehrenberg Dec 12, 2022
bea6cf0
docs for FlaxAutoModelForSpeechSeq2Seq
andyehrenberg Dec 12, 2022
e4270b4
fix decoder_position_ids computation in pretrained model decode/__cal…
andyehrenberg Dec 14, 2022
135b634
add Copied from statements as necessary
andyehrenberg Dec 15, 2022
21fe767
compute position_ids only in __call__ and decode methods of pretraine…
andyehrenberg Dec 16, 2022
a901674
improve readability of compute positional embeddings
andyehrenberg Dec 16, 2022
f8d4686
check dimensionality of input_features instead of hidden_states
andyehrenberg Dec 16, 2022
b407611
copied from statement for init_cache
andyehrenberg Dec 16, 2022
8e78c86
formatting
andyehrenberg Dec 16, 2022
810358c
fix copies
andyehrenberg Dec 16, 2022
b06a6ba
fix copies
andyehrenberg Dec 16, 2022
45efd60
pass attention mask to encoder layers
andyehrenberg Dec 21, 2022
718f53b
fix decoder module outputs
andyehrenberg Dec 21, 2022
07a24a8
set dtype
andyehrenberg Dec 22, 2022
43c4ed8
smaller flax model for whisper test
andyehrenberg Dec 22, 2022
ecaac58
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Dec 22, 2022
7b35907
Update src/transformers/generation/flax_utils.py
andyehrenberg Dec 31, 2022
8a4d990
Update src/transformers/models/whisper/modeling_flax_whisper.py
andyehrenberg Dec 31, 2022
17c22fe
Update tests/models/whisper/test_modeling_flax_whisper.py
andyehrenberg Dec 31, 2022
8c021ae
cleanup
andyehrenberg Dec 31, 2022
2aed9af
Update src/transformers/models/whisper/modeling_flax_whisper.py
andyehrenberg Dec 31, 2022
64da8fa
bias cleanup
andyehrenberg Dec 31, 2022
6fc7404
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Dec 31, 2022
618f85b
doc fix
andyehrenberg Dec 31, 2022
8b56bf4
align style for force tokens processor
andyehrenberg Jan 2, 2023
209834d
readability
andyehrenberg Jan 3, 2023
fac30a0
fix input shape in tests
andyehrenberg Jan 3, 2023
aa87c98
revert FlaxGenerationMixin docstring
andyehrenberg Jan 3, 2023
23af05b
formatting
andyehrenberg Jan 3, 2023
b8086b6
fix tests
andyehrenberg Jan 3, 2023
acef3e0
fix imports
andyehrenberg Jan 3, 2023
da1df33
consistent encoder hidden states
andyehrenberg Jan 3, 2023
4cdba95
consistent hidden states
andyehrenberg Jan 3, 2023
dd7473b
input shapes
andyehrenberg Jan 3, 2023
c5621f7
typo
andyehrenberg Jan 3, 2023
46aec12
partial class trick
andyehrenberg Jan 3, 2023
a003616
partial class for input shape
andyehrenberg Jan 3, 2023
a9604a5
base_class with correct input shape
andyehrenberg Jan 3, 2023
5120afe
partial base classes
andyehrenberg Jan 3, 2023
c6b1ae4
match by name
andyehrenberg Jan 3, 2023
4c239fc
set main_input_name
andyehrenberg Jan 4, 2023
279ceb6
compare on names
andyehrenberg Jan 4, 2023
b81630e
Merge branch 'main' into flax_whisper
andyehrenberg Jan 9, 2023
797fab1
formatting
andyehrenberg Jan 9, 2023
f3173d8
remove unused import
andyehrenberg Jan 9, 2023
b4696ca
safer position ids computation
andyehrenberg Jan 10, 2023
1c11ca6
safer position id computation
andyehrenberg Jan 10, 2023
c128fd8
Update src/transformers/models/whisper/modeling_flax_whisper.py
andyehrenberg Jan 18, 2023
2ae5b08
Update src/transformers/models/whisper/modeling_flax_whisper.py
andyehrenberg Jan 18, 2023
48583bd
remove identical inherited tests
andyehrenberg Jan 18, 2023
c93232f
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Jan 18, 2023
1c18f61
fix prompt ids in tests
andyehrenberg Jan 18, 2023
c3b1d34
use generation config
andyehrenberg Jan 18, 2023
bf15d5f
use jnp array
andyehrenberg Jan 18, 2023
c5fc14b
better var names
andyehrenberg Jan 18, 2023
161cb8a
more explicit bias use
andyehrenberg Jan 18, 2023
d9cedb9
Merge branch 'main' into flax_whisper
andyehrenberg Jan 18, 2023
bb9d0af
import transformers
andyehrenberg Jan 18, 2023
f1d90d2
formatting
andyehrenberg Jan 18, 2023
733ae2b
test formatting
andyehrenberg Jan 18, 2023
6295691
remove unused imports
andyehrenberg Jan 18, 2023
902555e
remove unused imports
andyehrenberg Jan 18, 2023
cba4942
formatting
andyehrenberg Jan 18, 2023
0173945
isort
andyehrenberg Jan 18, 2023
48640e5
docs
andyehrenberg Jan 18, 2023
1daee2b
fix ln orders for encoder hidden states
andyehrenberg Jan 26, 2023
fdb0a61
Merge branch 'main' into flax_whisper
andyehrenberg Feb 3, 2023
632c4be
whisper unique generation stuff
andyehrenberg Feb 3, 2023
95403d6
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Feb 3, 2023
c5c3ac1
flake
andyehrenberg Feb 3, 2023
907905f
use finfo for attention bias
andyehrenberg Feb 3, 2023
9dbcda8
docs
andyehrenberg Feb 3, 2023
d36cd2c
Update src/transformers/generation/flax_utils.py
andyehrenberg Feb 14, 2023
ab01cfc
docs
andyehrenberg Feb 14, 2023
62d172a
add timestamp flax test
andyehrenberg Feb 14, 2023
455b8bf
jit for timestamps
andyehrenberg Feb 14, 2023
89658d0
formatting
andyehrenberg Feb 14, 2023
a75fd03
clean up timestamps processor
andyehrenberg Feb 15, 2023
758d56c
formatting
andyehrenberg Feb 15, 2023
f9ac652
remove if_true
andyehrenberg Feb 17, 2023
94a526e
cleanup
andyehrenberg Feb 17, 2023
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -390,7 +390,7 @@ Flax), PyTorch, and/or TensorFlow.
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
| Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
- | Whisper | ✅ | ❌ | ✅ | ✅ | ❌ |
+ | Whisper | ✅ | ❌ | ✅ | ✅ | ✅ |
| X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
| XGLM | ✅ | ✅ | ✅ | ✅ | ✅ |
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/auto.mdx
@@ -286,6 +286,10 @@ The following auto classes are available for the following audio tasks.

[[autodoc]] TFAutoModelForSpeechSeq2Seq

### FlaxAutoModelForSpeechSeq2Seq

[[autodoc]] FlaxAutoModelForSpeechSeq2Seq
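
As a quick illustration (a sketch, not part of the diff), the auto class resolves a Whisper checkpoint to the new Flax class, assuming the checkpoint ships Flax weights (otherwise pass `from_pt=True`):

```python
from transformers import FlaxAutoModelForSpeechSeq2Seq

# For a Whisper checkpoint this resolves to FlaxWhisperForConditionalGeneration.
model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
```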

### AutoModelForAudioXVector

[[autodoc]] AutoModelForAudioXVector
11 changes: 11 additions & 0 deletions docs/source/en/model_doc/whisper.mdx
@@ -79,3 +79,14 @@ The original code can be found [here](https://github.com/openai/whisper).

[[autodoc]] TFWhisperForConditionalGeneration
- call


## FlaxWhisperModel

[[autodoc]] FlaxWhisperModel
- __call__

## FlaxWhisperForConditionalGeneration

[[autodoc]] FlaxWhisperForConditionalGeneration
- __call__
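
To make the new classes concrete, here is a minimal transcription sketch (not part of the diff; it assumes the `openai/whisper-tiny` checkpoint and a 16 kHz waveform):

```python
import numpy as np
from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Placeholder audio: 30 seconds of silence; substitute a real 16 kHz waveform.
audio = np.zeros(16000 * 30, dtype=np.float32)
input_features = processor(audio, sampling_rate=16000, return_tensors="np").input_features

# Flax generate returns an output object; the token ids live in `.sequences`.
generated_ids = model.generate(input_features).sequences
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```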
12 changes: 12 additions & 0 deletions src/transformers/__init__.py
@@ -3293,6 +3293,7 @@
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
"FLAX_MODEL_MAPPING",
@@ -3306,6 +3307,7 @@
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForSpeechSeq2Seq",
"FlaxAutoModelForTokenClassification",
"FlaxAutoModelForVision2Seq",
]
@@ -3489,6 +3491,13 @@
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
)
_import_structure["models.whisper"].extend(
[
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
]
)
_import_structure["models.xglm"].extend(
[
"FlaxXGLMForCausalLM",
@@ -6208,6 +6217,7 @@
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_MAPPING,
@@ -6221,6 +6231,7 @@
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForSpeechSeq2Seq,
FlaxAutoModelForTokenClassification,
FlaxAutoModelForVision2Seq,
)
@@ -6356,6 +6367,7 @@
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
from .models.xlm_roberta import (
FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
207 changes: 206 additions & 1 deletion src/transformers/generation/flax_logits_process.py
@@ -258,10 +258,215 @@ def __init__(self, min_length: int, eos_token_id: int):
self.eos_token_id = eos_token_id

def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:

# create boolean flag to decide if min length penalty should be applied
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)

return scores
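
The `apply_penalty` flag above is worth spelling out: `jnp.clip(cur_len - min_length, 0, 1)` turns the length comparison into a branchless 0/1 value, which keeps the processor traceable under `jit`. A tiny standalone sketch with hypothetical values:

```python
import jax.numpy as jnp

min_length = 5
for cur_len in (3, 5, 6):
    apply_penalty = 1 - jnp.clip(cur_len - min_length, 0, 1)
    print(cur_len, int(apply_penalty))  # 3 -> 1, 5 -> 1 (EOS still masked), 6 -> 0
```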


class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] suppressing a list of tokens as soon as the `generate` function starts generating using
`begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not sampled at the
beginning of the generation.

Args:
begin_suppress_tokens (`List[int]`):
Tokens to not sample.
begin_index (`int`):
Index where the tokens are suppressed.
"""

def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index

def __call__(self, input_ids, scores, cur_len: int):
apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index)

scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores)

return scores


class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs
to be `-inf` so they are not sampled.

Args:
suppress_tokens (`list`):
Tokens to not sample.
"""

def __init__(self, suppress_tokens: list):
self.suppress_tokens = list(suppress_tokens)

def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
scores = scores.at[..., self.suppress_tokens].set(-float("inf"))

return scores
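
In practice both suppress processors are built from a model's generation config; a hedged sketch (the checkpoint and `begin_index` value are illustrative):

```python
from transformers import GenerationConfig
from transformers.generation.flax_logits_process import (
    FlaxSuppressTokensAtBeginLogitsProcessor,
    FlaxSuppressTokensLogitsProcessor,
)

gen_config = GenerationConfig.from_pretrained("openai/whisper-tiny")
suppress_always = FlaxSuppressTokensLogitsProcessor(gen_config.suppress_tokens)
suppress_at_begin = FlaxSuppressTokensAtBeginLogitsProcessor(
    gen_config.begin_suppress_tokens, begin_index=1
)
```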


class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
to `-inf` so that they are sampled at their corresponding index.

Args:
force_token_map (`list`):
Map giving token ids and indices where they will be forced to be sampled.
"""

def __init__(self, force_token_map):
force_token_map = dict(force_token_map)
# Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
# index of the array corresponds to the index of the token to be forced, for XLA compatibility.
# Indexes without forced tokens will have a negative value.
force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1
for index, token in force_token_map.items():
force_token_array = force_token_array.at[index].set(token)
self.force_token_array = jnp.int32(force_token_array)

def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
def _force_token(generation_idx):
batch_size = scores.shape[0]
current_token = self.force_token_array[generation_idx]

new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
return new_scores

scores = lax.cond(
cur_len >= self.force_token_array.shape[0],
# If the current length is greater than or equal to the length of force_token_array, the processor does nothing.
lambda: scores,
# Otherwise, it may force a certain token.
lambda: lax.cond(
self.force_token_array[cur_len] >= 0,
# Only valid (positive) tokens are forced
lambda: _force_token(cur_len),
# Otherwise, the processor does nothing.
lambda: scores,
),
)
return scores
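
A small standalone check of the forcing behavior (token ids and vocabulary size are illustrative, in the style of Whisper's `forced_decoder_ids`):

```python
import jax.numpy as jnp
from transformers.generation.flax_logits_process import FlaxForceTokensLogitsProcessor

# Force token 50259 at generation index 1 and 50359 at index 2 (illustrative ids).
processor = FlaxForceTokensLogitsProcessor([[1, 50259], [2, 50359]])

scores = jnp.zeros((2, 51865))                  # (batch_size, vocab_size)
input_ids = jnp.zeros((2, 4), dtype=jnp.int32)  # unused by this processor
forced = processor(input_ids, scores, cur_len=1)
print(int(jnp.argmax(forced[0])))  # 50259: every other token is now -inf
```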


class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
r"""
Whisper-specific processor that enforces the model's timestamp rules during decoding: the `"<|notimestamps|>"`
token is suppressed, timestamp tokens must appear in pairs, and when the probability mass on timestamp tokens
exceeds that of any single text token, non-timestamp tokens are masked so that a timestamp is sampled.

Args:
generate_config (`GenerateConfig`):
The generate config used to generate the output. The following parameters are required:
eos_token_id (`int`, *optional*, defaults to 50257):
The id of the *end-of-sequence* token.
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
The id of the `"<|notimestamps|>"` token.
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
"""

def __init__(self, generate_config, model_config, decoder_input_length):
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1

self.begin_index = decoder_input_length + 1 # len(generate_config.forced_decoder_ids) + 1
# if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
andyehrenberg marked this conversation as resolved.
Show resolved Hide resolved
# self.begin_index -= 1
if generate_config.is_multilingual:
# room for language token and task token
self.begin_index += 2
if hasattr(generate_config, "max_initial_timestamp_index"):
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
else:
self.max_initial_timestamp_index = model_config.vocab_size
if self.max_initial_timestamp_index is None:
self.max_initial_timestamp_index = model_config.vocab_size

def __call__(self, input_ids, scores, cur_len):
# suppress <|notimestamps|> which is handled by without_timestamps
scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf"))
# if input_ids.shape[1] == self.begin_index:
# scores[:, self.timestamp_begin] = 0
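# Pairing rule enforced by handle_pairs below: once the last sampled token is a
# timestamp, either the pair must be completed (all tokens below EOS are masked,
# so only a timestamp or EOS can follow) or, if the previous two tokens were both
# timestamps, further timestamps are masked so that a text token must follow.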

def handle_pairs(input_ids_k, scores_k):
last_was_timestamp_1 = jax.lax.cond(
(cur_len - self.begin_index) >= 1,
lambda: True,
lambda: False,
)
last_was_timestamp_2 = jax.lax.cond(
input_ids_k[cur_len - 1] >= self.timestamp_begin,
lambda: True,
lambda: False,
)
last_was_timestamp = last_was_timestamp_1 * last_was_timestamp_2

penultimate_was_timestamp_1 = jax.lax.cond(
(cur_len - self.begin_index) < 2,
lambda: True,
lambda: False,
)
penultimate_was_timestamp_2 = jax.lax.cond(
input_ids_k[cur_len - 2] >= self.timestamp_begin,
lambda: True,
lambda: False,
)
penultimate_was_timestamp = penultimate_was_timestamp_1 + penultimate_was_timestamp_2

def if_true():
return jax.lax.cond(
penultimate_was_timestamp > 0,
lambda: scores_k.at[self.timestamp_begin :].set(-float("inf")),
lambda: scores_k.at[: self.eos_token_id].set(-float("inf")),
)

return jax.lax.cond(last_was_timestamp, if_true, lambda: scores_k)

scores = jax.vmap(handle_pairs)(input_ids, scores)

apply_max_initial_timestamp = jax.lax.cond(
cur_len == self.begin_index,
lambda: True,
lambda: False,
)
apply_max_initial_timestamp = jax.lax.cond(
self.max_initial_timestamp_index is not None,
lambda: True and apply_max_initial_timestamp,
lambda: False,
)
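# At the very first sampling position, clamp how far ahead the initial timestamp
# may point: tokens above timestamp_begin + max_initial_timestamp_index get -inf.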

def if_true():
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
return scores.at[:, last_allowed + 1 :].set(-float("inf"))

scores = jnp.where(
apply_max_initial_timestamp,
if_true(),
scores,
)

# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = jax.nn.log_softmax(scores, axis=-1)

def handle_cumulative_probs(logprobs_k, scores_k):
timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1)
max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin])
return jnp.where(
timestamp_logprob > max_text_token_logprob,
scores_k.at[: self.timestamp_begin].set(-float("inf")),
scores_k,
)

scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)

return scores
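
Finally, a hedged construction sketch for the timestamp processor (the checkpoint and `decoder_input_length` are illustrative; inside `generate` these come from the generation config and the decoder prompt):

```python
from transformers import FlaxWhisperForConditionalGeneration, GenerationConfig
from transformers.generation.flax_logits_process import FlaxWhisperTimeStampLogitsProcessor

model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
gen_config = GenerationConfig.from_pretrained("openai/whisper-tiny")

# decoder_input_length: number of decoder prompt tokens before sampling begins.
processor = FlaxWhisperTimeStampLogitsProcessor(gen_config, model.config, decoder_input_length=1)
```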