drop python 3.7 support #889

Merged · 1 commit · Jan 24, 2023
19 changes: 6 additions & 13 deletions whisper/decoding.py
@@ -252,11 +252,10 @@ def __init__(self, temperature: float, eot: int):
         self.eot = eot
 
     def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
-        temperature = self.temperature
-        if temperature == 0:
+        if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
-            next_tokens = Categorical(logits=logits / temperature).sample()
+            next_tokens = Categorical(logits=logits / self.temperature).sample()
 
         logprobs = F.log_softmax(logits.float(), dim=-1)
         current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
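This hunk is behavior-preserving; for reference, a tiny runnable sketch of the two branches in `GreedyDecoder.update` (the toy logits and temperature are made up, not values from the codebase):

```python
import torch
from torch.distributions import Categorical

# Toy inputs, purely illustrative.
logits = torch.tensor([[2.0, 1.0, 0.5]])
temperature = 0.7

# temperature == 0 would take the argmax (greedy decoding)...
greedy = logits.argmax(dim=-1)  # tensor([0])

# ...otherwise sample from the temperature-scaled distribution.
sampled = Categorical(logits=logits / temperature).sample()
print(greedy, sampled)
```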
@@ -511,10 +510,8 @@ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
 
     def _get_initial_tokens(self) -> Tuple[int]:
         tokens = list(self.sot_sequence)
-        prefix = self.options.prefix
-        prompt = self.options.prompt
 
-        if prefix:
+        if prefix := self.options.prefix:
             prefix_tokens = (
                 self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
             )
@@ -523,7 +520,7 @@ def _get_initial_tokens(self) -> Tuple[int]:
             prefix_tokens = prefix_tokens[-max_prefix_len:]
             tokens = tokens + prefix_tokens
 
-        if prompt:
+        if prompt := self.options.prompt:
             prompt_tokens = (
                 self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
             )
@@ -698,13 +695,9 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()):
     result: Union[DecodingResult, List[DecodingResult]]
         The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
     """
-    single = mel.ndim == 2
-    if single:
+    if single := mel.ndim == 2:
         mel = mel.unsqueeze(0)
 
     result = DecodingTask(model, options).run(mel)
 
-    if single:
-        result = result[0]
-
-    return result
+    return result[0] if single else result
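The hunks above all lean on the assignment-expression ("walrus") operator, which only exists on Python 3.8+, so dropping 3.7 is what makes this cleanup possible. A minimal sketch of the before/after pattern, with illustrative names rather than the actual whisper code:

```python
from types import SimpleNamespace

options = SimpleNamespace(prefix="  hello ")

# Python 3.7 style: bind the value in one statement, test it in the next.
prefix = options.prefix
if prefix:
    print(prefix.strip())

# Python 3.8+: `:=` binds and tests in a single expression.
if prefix := options.prefix:
    print(prefix.strip())
```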
38 changes: 13 additions & 25 deletions whisper/tokenizer.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import dataclass
-from functools import lru_cache
+from functools import lru_cache, cached_property
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -156,43 +156,35 @@ def decode_with_timestamps(self, tokens) -> str:
         outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
         return "".join(outputs)
 
-    @property
-    @lru_cache()
+    @cached_property
     def eot(self) -> int:
         return self.tokenizer.eos_token_id
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot(self) -> int:
         return self._get_single_token_id("<|startoftranscript|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_lm(self) -> int:
         return self._get_single_token_id("<|startoflm|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_prev(self) -> int:
         return self._get_single_token_id("<|startofprev|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_speech(self) -> int:
         return self._get_single_token_id("<|nospeech|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_timestamps(self) -> int:
         return self._get_single_token_id("<|notimestamps|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def timestamp_begin(self) -> int:
         return self.tokenizer.all_special_ids[-1] + 1
 
-    @property
-    @lru_cache()
+    @cached_property
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
@@ -210,8 +202,7 @@ def language_token(self) -> int:
 
         raise KeyError(f"Language {self.language} not found in tokenizer.")
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_tokens(self) -> Tuple[int]:
         result = []
         for token, token_id in zip(
@@ -222,18 +213,15 @@ def all_language_tokens(self) -> Tuple[int]:
                 result.append(token_id)
         return tuple(result)
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_codes(self) -> Tuple[str]:
         return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
         return tuple(list(self.sot_sequence) + [self.no_timestamps])
 
-    @property
-    @lru_cache()
+    @cached_property
     def non_speech_tokens(self) -> Tuple[int]:
         """
         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
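Throughout tokenizer.py, the stacked `@property` / `@lru_cache()` pairs collapse into `functools.cached_property`, added in Python 3.8. A small self-contained sketch of the equivalence; the `Special` class and token id below are made up for illustration:

```python
from functools import lru_cache, cached_property


class Special:
    # Python 3.7 style: a memoized getter exposed as a property. Note that
    # lru_cache on a method holds a strong reference to `self`, so instances
    # stay alive as long as the cache does.
    @property
    @lru_cache()
    def eot_old(self) -> int:
        return 50256  # made-up token id

    # Python 3.8+: computed on first access, then stored in the instance's
    # __dict__ so later lookups bypass the getter entirely.
    @cached_property
    def eot(self) -> int:
        return 50256


s = Special()
assert s.eot == s.eot_old == 50256
```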
25 changes: 14 additions & 11 deletions whisper/transcribe.py
@@ -26,6 +26,7 @@ def transcribe(
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
+    initial_prompt: Optional[str] = None,
     **decode_options,
 ):
     """
@@ -138,10 +139,11 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
     all_segments = []
     prompt_reset_since = 0
 
-    initial_prompt = decode_options.pop("initial_prompt", None) or []
-    if initial_prompt:
-        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
-        all_tokens.extend(initial_prompt)
+    if initial_prompt is not None:
+        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+        all_tokens.extend(initial_prompt_tokens)
+    else:
+        initial_prompt_tokens = []
 
     def add_segment(
         *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
@@ -243,7 +245,11 @@ def add_segment(
             pbar.update(min(num_frames, seek) - previous_seek_value)
             previous_seek_value = seek
 
-    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
+    return dict(
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
+        segments=all_segments,
+        language=language
+    )
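Since `initial_prompt` is now an explicit keyword argument instead of a key popped from `**decode_options`, it shows up in the function signature. A hedged usage sketch (model size, file name, and prompt text are placeholders):

```python
import whisper

model = whisper.load_model("base")
result = model.transcribe(
    "audio.mp3",
    initial_prompt="Glossary: Whisper, spectrogram, beam search.",
)
print(result["text"])
```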


def cli():
Expand Down Expand Up @@ -292,21 +298,18 @@ def cli():
args["language"] = "en"

temperature = args.pop("temperature")
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
else:
temperature = [temperature]

threads = args.pop("threads")
if threads > 0:
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)

from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)

writer = get_writer(output_format, output_dir)

for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path)
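The `cli()` refactor is also behavior-preserving; with the default `--temperature 0` and `--temperature_increment_on_fallback 0.2`, the fallback schedule it builds looks like this sketch:

```python
import numpy as np

temperature = 0.0  # default --temperature
increment = 0.2    # default --temperature_increment_on_fallback

if increment is not None:
    # The 1e-6 nudges the stop bound so that 1.0 itself is included
    # despite floating-point rounding.
    temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
else:
    temperature = [temperature]

print(temperature)  # ~(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), modulo float noise
```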