diff --git a/whisper/decoding.py b/whisper/decoding.py
index bb70cc024..983c898a3 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -252,11 +252,10 @@ def __init__(self, temperature: float, eot: int):
         self.eot = eot
 
     def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
-        temperature = self.temperature
-        if temperature == 0:
+        if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
-            next_tokens = Categorical(logits=logits / temperature).sample()
+            next_tokens = Categorical(logits=logits / self.temperature).sample()
 
         logprobs = F.log_softmax(logits.float(), dim=-1)
         current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
@@ -511,10 +510,8 @@ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
 
     def _get_initial_tokens(self) -> Tuple[int]:
         tokens = list(self.sot_sequence)
-        prefix = self.options.prefix
-        prompt = self.options.prompt
 
-        if prefix:
+        if prefix := self.options.prefix:
             prefix_tokens = (
                 self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
             )
@@ -523,7 +520,7 @@ def _get_initial_tokens(self) -> Tuple[int]:
                 prefix_tokens = prefix_tokens[-max_prefix_len:]
             tokens = tokens + prefix_tokens
 
-        if prompt:
+        if prompt := self.options.prompt:
             prompt_tokens = (
                 self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
             )
@@ -698,13 +695,9 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()):
     result: Union[DecodingResult, List[DecodingResult]]
         The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
     """
-    single = mel.ndim == 2
-    if single:
+    if single := mel.ndim == 2:
         mel = mel.unsqueeze(0)
 
     result = DecodingTask(model, options).run(mel)
 
-    if single:
-        result = result[0]
-
-    return result
+    return result[0] if single else result
diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py
index a27cb359e..7b4605f3c 100644
--- a/whisper/tokenizer.py
+++ b/whisper/tokenizer.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import dataclass
-from functools import lru_cache
+from functools import lru_cache, cached_property
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -156,43 +156,35 @@ def decode_with_timestamps(self, tokens) -> str:
         outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
         return "".join(outputs)
 
-    @property
-    @lru_cache()
+    @cached_property
     def eot(self) -> int:
         return self.tokenizer.eos_token_id
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot(self) -> int:
         return self._get_single_token_id("<|startoftranscript|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_lm(self) -> int:
         return self._get_single_token_id("<|startoflm|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_prev(self) -> int:
         return self._get_single_token_id("<|startofprev|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_speech(self) -> int:
         return self._get_single_token_id("<|nospeech|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_timestamps(self) -> int:
         return self._get_single_token_id("<|notimestamps|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def timestamp_begin(self) -> int:
         return self.tokenizer.all_special_ids[-1] + 1
 
-    @property
-    @lru_cache()
+    @cached_property
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
@@ -210,8 +202,7 @@ def language_token(self) -> int:
 
         raise KeyError(f"Language {self.language} not found in tokenizer.")
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_tokens(self) -> Tuple[int]:
         result = []
         for token, token_id in zip(
@@ -222,18 +213,15 @@ def all_language_tokens(self) -> Tuple[int]:
                 result.append(token_id)
         return tuple(result)
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_codes(self) -> Tuple[str]:
         return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
         return tuple(list(self.sot_sequence) + [self.no_timestamps])
 
-    @property
-    @lru_cache()
+    @cached_property
     def non_speech_tokens(self) -> Tuple[int]:
         """
         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index c5bea7b26..80bdd7965 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -26,6 +26,7 @@ def transcribe(
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
+    initial_prompt: Optional[str] = None,
     **decode_options,
 ):
     """
@@ -138,10 +139,11 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
     all_segments = []
     prompt_reset_since = 0
 
-    initial_prompt = decode_options.pop("initial_prompt", None) or []
-    if initial_prompt:
-        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
-        all_tokens.extend(initial_prompt)
+    if initial_prompt is not None:
+        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+        all_tokens.extend(initial_prompt_tokens)
+    else:
+        initial_prompt_tokens = []
 
     def add_segment(
         *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
@@ -243,7 +245,11 @@ def add_segment(
             pbar.update(min(num_frames, seek) - previous_seek_value)
             previous_seek_value = seek
 
-    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
+    return dict(
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
+        segments=all_segments,
+        language=language
+    )
 
 
 def cli():
@@ -292,21 +298,18 @@ def cli():
         args["language"] = "en"
 
     temperature = args.pop("temperature")
-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
-    if temperature_increment_on_fallback is not None:
-        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
+    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
+        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
     else:
         temperature = [temperature]
 
-    threads = args.pop("threads")
-    if threads > 0:
+    if (threads := args.pop("threads")) > 0:
         torch.set_num_threads(threads)
 
     from . import load_model
     model = load_model(model_name, device=device, download_root=model_dir)
 
     writer = get_writer(output_format, output_dir)
-
    for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
         writer(result, audio_path)
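
A minimal sketch (illustration only, not part of the patch) of the caching difference that motivates the `cached_property` change in `tokenizer.py`: `lru_cache` keys its cache on the arguments, and for a method the first argument is `self`, so the class-level cache keeps a strong reference to every instance it has seen, whereas `cached_property` stores the computed value in the instance's own `__dict__`. Note that both `cached_property` and the walrus operator (`:=`) used elsewhere in this diff require Python 3.8+.

```python
from functools import cached_property, lru_cache


class WithPropertyLruCache:
    # Old pattern: the lru_cache entry is keyed on `self`, so the
    # module-level cache holds a strong reference to each instance.
    @property
    @lru_cache()
    def value(self) -> int:
        return 42


class WithCachedProperty:
    # New pattern: the value is computed once on first access and then
    # stored on the instance itself, so nothing outlives the instance.
    @cached_property
    def value(self) -> int:
        return 42


obj = WithCachedProperty()
assert obj.value == 42            # computed on first access
assert "value" in obj.__dict__    # ...then cached in the instance __dict__
```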
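The `transcribe.py` change promotes `initial_prompt` from a `**decode_options` entry to an explicit keyword argument. A hypothetical usage sketch (the audio path and prompt text are placeholders):

```python
import whisper

model = whisper.load_model("base")

# `initial_prompt` seeds the decoder with context tokens before the
# first window is transcribed, e.g. to bias spelling of domain terms.
result = model.transcribe(
    "audio.mp3",  # placeholder path
    initial_prompt="Glossary: Whisper, mel spectrogram, beam search.",
)
print(result["text"])
```

Callers could already pass `initial_prompt` through `**decode_options`; after this patch the signature documents it, and the encoded tokens get a distinct name (`initial_prompt_tokens`) instead of rebinding the string argument.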