Skip to content

Commit

Permalink
fix (#1082)
Browse files Browse the repository at this point in the history
  • Loading branch information
awni authored Nov 2, 2024
1 parent 0f79994 commit 29c954f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 25 deletions.
2 changes: 1 addition & 1 deletion whisper/mlx_whisper/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.4.0"
__version__ = "0.4.1"
47 changes: 23 additions & 24 deletions whisper/mlx_whisper/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,35 +589,34 @@ def _step(inputs, audio_features, tokens, sum_logprobs):
)
return tokens, completed, sum_logprobs, pre_logits

try:
tokens, completed, sum_logprobs, pre_logits = _step(
tokens, audio_features, tokens, sum_logprobs
tokens, completed, sum_logprobs, pre_logits = _step(
tokens, audio_features, tokens, sum_logprobs
)
if self.tokenizer.no_speech is not None: # compute no_speech_probs
probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
else:
no_speech_probs = mx.full(n_batch, mx.nan)
mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)

for i in range(1, self.sample_len):
inputs = tokens[:, -1:]
if tokens.shape[-1] > self.n_ctx:
break
next_tokens, next_completed, next_sum_logprobs, _ = _step(
inputs, audio_features, tokens, sum_logprobs
)
if self.tokenizer.no_speech is not None: # compute no_speech_probs
probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
else:
no_speech_probs = mx.full(n_batch, mx.nan)
mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)

for i in range(1, self.sample_len):
inputs = tokens[:, -1:]
next_tokens, next_completed, next_sum_logprobs, _ = _step(
inputs, audio_features, tokens, sum_logprobs
)
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
if completed or tokens.shape[-1] > self.n_ctx:
break
tokens = next_tokens
completed = next_completed
sum_logprobs = next_sum_logprobs

finally:
self.inference.reset()
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
if completed:
break
tokens = next_tokens
completed = next_completed
sum_logprobs = next_sum_logprobs

return tokens, sum_logprobs, no_speech_probs

def run(self, mel: mx.array) -> List[DecodingResult]:
self.inference.reset()
self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0]
Expand Down

0 comments on commit 29c954f

Please sign in to comment.