Whisper improvements (#1080)
* use safetensors in whisper

* speed up decoder

* version
awni authored Nov 1, 2024
1 parent 85ffd2c commit 8160e0c
Showing 6 changed files with 86 additions and 65 deletions.
4 changes: 2 additions & 2 deletions whisper/convert.py
@@ -181,7 +181,7 @@ def load_torch_weights_and_config(
     )

     if name_or_path.endswith(".pt"):
-        checkpoint = torch.load(name_or_path, map_location="cpu")
+        checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)
         weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
     else:
         name_or_path = Path(name_or_path)
@@ -387,7 +387,7 @@ def quantize(weights, config, args):

     # Save weights
     print("[INFO] Saving")
-    np.savez(str(mlx_path / "weights.npz"), **weights)
+    mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)

     # Save config.json with model_type
     with open(str(mlx_path / "config.json"), "w") as f:
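For reference, a minimal standalone sketch (not from this repo) of the safetensors round trip the converter now performs; the filename mirrors the "weights.safetensors" written above:

import mlx.core as mx

# Save a dict of mx.array to a single safetensors file; mx.load picks the
# format from the file extension and returns {name: mx.array}.
weights = {"encoder.weight": mx.zeros((4, 4)), "encoder.bias": mx.zeros((4,))}
mx.save_safetensors("weights.safetensors", weights)
loaded = mx.load("weights.safetensors")
assert loaded["encoder.weight"].shape == (4, 4)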
2 changes: 1 addition & 1 deletion whisper/mlx_whisper/_version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.3.0"
+__version__ = "0.4.0"
134 changes: 76 additions & 58 deletions whisper/mlx_whisper/decoding.py
@@ -58,11 +58,12 @@ def detect_language(
     logits = model.logits(x, mel)[:, 0]

     # collect detected languages; suppress all non-language tokens
-    mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32)
+    mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32)
     mask[list(tokenizer.all_language_tokens)] = 0.0
-    logits += mx.array(mask)
+    logits += mask
     language_tokens = mx.argmax(logits, axis=-1)
     language_token_probs = mx.softmax(logits, axis=-1)
+    language_token_probs = np.array(language_token_probs)
     language_probs = [
         {
             c: language_token_probs[i, j].item()
@@ -129,17 +130,12 @@ class DecodingResult:


 class Inference:
-    def __init__(self, model: "Whisper", initial_token_length: int):
+    def __init__(self, model: "Whisper"):
         self.model: "Whisper" = model
-        self.initial_token_length = initial_token_length
         self.kv_cache = None

     def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
         """Perform a forward pass on the decoder and return per-token logits"""
-        if tokens.shape[-1] > self.initial_token_length:
-            # only need to use the last token except in the first forward pass
-            tokens = tokens[:, -1:]
-
         logits, self.kv_cache, _ = self.model.decoder(
             tokens, audio_features, kv_cache=self.kv_cache
         )
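With the `initial_token_length` check gone, `Inference.logits` decodes whatever the caller passes: the full prompt on the first step, then only the newest token, with `kv_cache` carrying the rest. A hedged sketch of that pattern, where `decoder` and its return convention are assumptions standing in for `model.decoder` above:

import mlx.core as mx

def generate(decoder, audio_features, tokens, steps):
    # assumed: decoder(inputs, audio_features, kv_cache=...) -> (logits, kv_cache, _)
    kv_cache = None
    inputs = tokens  # first pass: the whole prompt populates the cache
    for _ in range(steps):
        logits, kv_cache, _ = decoder(inputs, audio_features, kv_cache=kv_cache)
        next_token = logits[:, -1].argmax(axis=-1)[:, None]
        tokens = mx.concatenate([tokens, next_token], axis=-1)
        inputs = next_token  # later passes: only the new token is decoded
    return tokens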
@@ -251,6 +247,11 @@ def finalize(
         raise NotImplementedError


+@mx.compile
+def categorical(logits, temp):
+    return mx.random.categorical(logits / temp)
+
+
 class GreedyDecoder(TokenDecoder):
     def __init__(self, temperature: float, eot: int):
         self.temperature = temperature
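The sampling step is now wrapped in `mx.compile`, so the temperature scaling and categorical draw are traced once and replayed as a fused graph on later steps. A standalone sketch of the same pattern:

import mlx.core as mx

@mx.compile
def categorical(logits, temp):
    # compiled once per input shape/dtype; temp is a Python float here,
    # so it is treated as a constant and a new value triggers a retrace
    return mx.random.categorical(logits / temp)

logits = mx.random.normal((2, 51864))  # (batch, vocab); vocab size is illustrative
tokens = categorical(logits, 0.7)      # shape (2,), one sampled token per row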
@@ -262,10 +263,8 @@ def update(
         if self.temperature == 0:
             next_tokens = logits.argmax(axis=-1)
         else:
-            next_tokens = mx.random.categorical(logits=logits / self.temperature)
+            next_tokens = categorical(logits, self.temperature)

-        next_tokens = mx.argmax(logits, axis=-1)
-        logits = logits.astype(mx.float32)
         logprobs = logits - mx.logsumexp(logits, axis=-1)

         current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
@@ -281,7 +280,7 @@ def update(
     def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
         # make sure each sequence has at least one EOT token at the end
         tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
-        return tokens, sum_logprobs.tolist()
+        return tokens, sum_logprobs


 class LogitFilter:
@@ -340,10 +339,10 @@ def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
         if self.tokenizer.no_timestamps is not None:
             mask[:, self.tokenizer.no_timestamps] = -np.inf

         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
-        for k in range(tokens.shape[0]):
-            sampled_tokens = tokens[k, self.sample_begin :]
-            seq = sampled_tokens.tolist()
+        tokens = tokens.tolist()
+        for k in range(len(tokens)):
+            seq = tokens[k][self.sample_begin :]
             last_was_timestamp = (
                 len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
             )
@@ -368,7 +367,7 @@ def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
                     last_timestamp += 1
                 mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf

-        if tokens.shape[1] == self.sample_begin:
+        if len(tokens[0]) == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
             mask[:, : self.tokenizer.timestamp_begin] = -np.inf

@@ -380,16 +379,20 @@ def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
             mask[:, last_allowed + 1 :] = -np.inf

         # if sum of probability over timestamps is above any other token, sample timestamp
+        mask = mx.array(mask)
         logprobs = logits - mx.logsumexp(logits, axis=-1)
-        for k in range(tokens.shape[0]):
-            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
-                axis=-1
-            )
-            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
-            if timestamp_logprob > max_text_token_logprob:
-                mask[k, : self.tokenizer.timestamp_begin] = -np.inf
-
-        return logits + mx.array(mask, logits.dtype)
+        timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
+            axis=-1, keepdims=True
+        )
+        max_text_token_logprob = logprobs[:, : self.tokenizer.timestamp_begin].max(
+            axis=-1, keepdims=True
+        )
+        mask[:, : self.tokenizer.timestamp_begin] = mx.where(
+            timestamp_logprob > max_text_token_logprob,
+            -mx.inf,
+            mask[:, : self.tokenizer.timestamp_begin],
+        )
+        return logits + mask


 class DecodingTask:
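The per-row Python loop above is replaced by one broadcasted comparison: `keepdims=True` keeps the reduced values as (batch, 1) columns so `mx.where` can write `-inf` across every qualifying row at once. A small self-contained sketch of the trick, where `split` is a hypothetical stand-in for `tokenizer.timestamp_begin`:

import mlx.core as mx

logprobs = mx.random.normal((3, 10))
split = 6  # hypothetical boundary between text and timestamp tokens
ts_logprob = logprobs[:, split:].logsumexp(axis=-1, keepdims=True)  # (3, 1)
max_text = logprobs[:, :split].max(axis=-1, keepdims=True)          # (3, 1)
mask = mx.zeros((3, 10))
# rows where timestamps win get -inf over all text positions, no loop needed
mask[:, :split] = mx.where(ts_logprob > max_text, -mx.inf, mask[:, :split])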
@@ -424,17 +427,14 @@ def __init__(self, model: "Whisper", options: DecodingOptions):
         self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

         # inference: implements the forward pass through the decoder, including kv caching
-        self.inference = Inference(model, len(self.initial_tokens))
+        self.inference = Inference(model)

         # sequence ranker: implements how to rank a group of sampled sequences
         self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

         # decoder: implements how to select the next tokens, given the autoregressive distribution
         if options.beam_size is not None:
             raise NotImplementedError("Beam search decoder is not yet implemented")
-            # self.decoder = BeamSearchDecoder(
-            #     options.beam_size, tokenizer.eot, self.inference, options.patience
-            # )
         else:
             self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

@@ -448,6 +448,7 @@ def __init__(self, model: "Whisper", options: DecodingOptions):
         self.logit_filters.append(
             SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
         )
+
         if not options.without_timestamps:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None
@@ -570,35 +571,47 @@ def _detect_language(self, audio_features: mx.array, tokens: np.array):

     def _main_loop(self, audio_features: mx.array, tokens: mx.array):
         n_batch = tokens.shape[0]
-        sum_logprobs: mx.array = mx.zeros(n_batch)
-        no_speech_probs = [np.nan] * n_batch
+        sum_logprobs = mx.zeros(n_batch)
+
+        def _step(inputs, audio_features, tokens, sum_logprobs):
+            pre_logits = self.inference.logits(inputs, audio_features)
+
+            # consider the logits at the last token only
+            logits = pre_logits[:, -1]
+
+            # apply the logit filters, e.g. for suppressing or applying penalty to
+            for logit_filter in self.logit_filters:
+                logits = logit_filter.apply(logits, tokens)
+
+            # expand the tokens tensor with the selected next tokens
+            tokens, completed, sum_logprobs = self.decoder.update(
+                tokens, logits, sum_logprobs
+            )
+            return tokens, completed, sum_logprobs, pre_logits

         try:
-            for i in range(self.sample_len):
-                logits = self.inference.logits(tokens, audio_features)
-
-                if (
-                    i == 0 and self.tokenizer.no_speech is not None
-                ):  # save no_speech_probs
-                    probs_at_sot = mx.softmax(
-                        logits[:, self.sot_index].astype(mx.float32), axis=-1
-                    )
-                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
-
-                # now we need to consider the logits at the last token only
-                logits = logits[:, -1]
-
-                # apply the logit filters, e.g. for suppressing or applying penalty to
-                for logit_filter in self.logit_filters:
-                    logits = logit_filter.apply(logits, tokens)
-
-                # expand the tokens tensor with the selected next tokens
-                tokens, completed, sum_logprobs = self.decoder.update(
-                    tokens, logits, sum_logprobs
-                )
+            tokens, completed, sum_logprobs, pre_logits = _step(
+                tokens, audio_features, tokens, sum_logprobs
+            )
+            if self.tokenizer.no_speech is not None:  # compute no_speech_probs
+                probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
+                no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
+            else:
+                no_speech_probs = mx.full(n_batch, mx.nan)
+            mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
+
+            for i in range(1, self.sample_len):
+                inputs = tokens[:, -1:]
+                next_tokens, next_completed, next_sum_logprobs, _ = _step(
+                    inputs, audio_features, tokens, sum_logprobs
+                )
+
+                mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
                 if completed or tokens.shape[-1] > self.n_ctx:
                     break
+                tokens = next_tokens
+                completed = next_completed
+                sum_logprobs = next_sum_logprobs

         finally:
             self.inference.reset()
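The rewritten loop is a pipeline: `mx.async_eval` launches the graph for step i+1 before the host-side `if completed` check forces a wait on step i, so Python bookkeeping overlaps with device work. A hedged sketch of the scheduling pattern, with hypothetical `step` and `state` standing in for `_step` and its outputs:

import mlx.core as mx

def pipelined_decode(step, state, max_steps):
    # assumed: step(state) -> next_state, a pure function of its mx.array inputs
    mx.async_eval(state)  # kick off work for the current state
    for _ in range(max_steps):
        next_state = step(state)
        mx.async_eval(next_state)     # enqueue step i+1 before syncing on step i
        if bool(state["completed"]):  # reading `completed` waits on step i only
            break
        state = next_state
    return state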

Expand All @@ -610,8 +623,8 @@ def run(self, mel: mx.array) -> List[DecodingResult]:
n_audio: int = mel.shape[0]

audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass
tokens: np.array = np.array(self.initial_tokens)
tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy()
tokens: mx.array = mx.array(self.initial_tokens)
tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens)))

# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)
@@ -626,7 +639,6 @@ def run(self, mel: mx.array) -> List[DecodingResult]:
         ]

         # repeat tokens by the group size, for beam search or best-of-n sampling
-        tokens = mx.array(tokens)
         if self.n_group > 1:
             tokens = tokens[:, None, :]
             tokens = mx.broadcast_to(
@@ -649,7 +661,13 @@ def run(self, mel: mx.array) -> List[DecodingResult]:

         # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
-        tokens = tokens[..., self.sample_begin :].tolist()
+        tokens = tokens[..., self.sample_begin :]
+
+        # eval and convert to list
+        mx.eval(tokens, sum_logprobs, no_speech_probs)
+        tokens = tokens.tolist()
+        sum_logprobs = sum_logprobs.tolist()
+        no_speech_probs = no_speech_probs.tolist()
         tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]

         # select the top-ranked sample in each group
5 changes: 4 additions & 1 deletion whisper/mlx_whisper/load_models.py
@@ -26,7 +26,10 @@ def load_model(

     model_args = whisper.ModelDimensions(**config)

-    weights = mx.load(str(model_path / "weights.npz"))
+    wf = model_path / "weights.safetensors"
+    if not wf.exists():
+        wf = model_path / "weights.npz"
+    weights = mx.load(str(wf))

     model = whisper.Whisper(model_args, dtype)

1 change: 1 addition & 0 deletions whisper/mlx_whisper/transcribe.py
@@ -293,6 +293,7 @@ def new_segment(

         decode_options["prompt"] = all_tokens[prompt_reset_since:]
         result: DecodingResult = decode_with_fallback(mel_segment)
+
         tokens = np.array(result.tokens)

         if no_speech_threshold is not None:
5 changes: 2 additions & 3 deletions whisper/mlx_whisper/whisper.py
@@ -80,12 +80,11 @@ def qkv_attention(self, q, k, v, mask=None):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
-        qk = qk.astype(mx.float32)

-        w = mx.softmax(qk, axis=-1).astype(q.dtype)
+        w = mx.softmax(qk, axis=-1, precise=True)
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
-        return out, qk
+        return out, qk.astype(mx.float32)


 class ResidualAttentionBlock(nn.Module):
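With `precise=True`, `mx.softmax` accumulates in float32 even for half-precision inputs while returning the input dtype, which subsumes the manual `astype(mx.float32)` / `astype(q.dtype)` round trip removed above. A quick check of that assumed behavior:

import mlx.core as mx

qk = mx.random.normal((2, 8)).astype(mx.float16)
w = mx.softmax(qk, axis=-1, precise=True)  # float32 accumulation internally
assert w.dtype == mx.float16               # output keeps the input dtype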
