From 62241c0f718efb70703b69f2431fba463316109c Mon Sep 17 00:00:00 2001 From: Ryan Heise Date: Sun, 2 Apr 2023 11:52:24 +1000 Subject: [PATCH 1/2] Add highlight_words, max_line_width, max_line_count --- whisper/transcribe.py | 13 +++- whisper/utils.py | 143 ++++++++++++++++++++++++++++++------------ 2 files changed, 115 insertions(+), 41 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index ed6d8205d..84feb12a9 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -401,6 +401,9 @@ def cli(): parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") + parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") + parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line") + parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") # fmt: on @@ -433,9 +436,17 @@ def cli(): model = load_model(model_name, device=device, download_root=model_dir) writer = get_writer(output_format, output_dir) + word_options = ["highlight_words", "max_line_count", "max_line_width"] + if not args["word_timestamps"]: + for option in 
word_options: + if args[option]: + parser.error(f"--{option} requires --word_timestamps True") + if args["max_line_count"] and not args["max_line_width"]: + warnings.warn("--max_line_count has no effect without --max_line_width") + writer_args = {arg: args.pop(arg) for arg in word_options} for audio_path in args.pop("audio"): result = transcribe(model, audio_path, temperature=temperature, **args) - writer(result, audio_path) + writer(result, audio_path, writer_args) if __name__ == "__main__": diff --git a/whisper/utils.py b/whisper/utils.py index 490bdd19f..d7d27f8c2 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -1,8 +1,9 @@ import json import os +import re import sys import zlib -from typing import Callable, TextIO +from typing import Callable, TextIO, Union system_encoding = sys.getdefaultencoding() @@ -73,7 +74,7 @@ class ResultWriter: def __init__(self, output_dir: str): self.output_dir = output_dir - def __call__(self, result: dict, audio_path: str): + def __call__(self, result: dict, audio_path: str, options: dict): audio_basename = os.path.basename(audio_path) audio_basename = os.path.splitext(audio_basename)[0] output_path = os.path.join( @@ -81,16 +82,16 @@ def __call__(self, result: dict, audio_path: str): ) with open(output_path, "w", encoding="utf-8") as f: - self.write_result(result, file=f) + self.write_result(result, file=f, options=options) - def write_result(self, result: dict, file: TextIO): + def write_result(self, result: dict, file: TextIO, options: dict): raise NotImplementedError class WriteTXT(ResultWriter): extension: str = "txt" - def write_result(self, result: dict, file: TextIO): + def write_result(self, result: dict, file: TextIO, options: dict): for segment in result["segments"]: print(segment["text"].strip(), file=file, flush=True) @@ -99,33 +100,91 @@ class SubtitlesWriter(ResultWriter): always_include_hours: bool decimal_marker: str - def iterate_result(self, result: dict): - for segment in result["segments"]: - 
segment_start = self.format_timestamp(segment["start"]) - segment_end = self.format_timestamp(segment["end"]) - segment_text = segment["text"].strip().replace("-->", "->") - - if word_timings := segment.get("words", None): - all_words = [timing["word"] for timing in word_timings] - all_words[0] = all_words[0].strip() # remove the leading space, if any - last = segment_start - for i, this_word in enumerate(word_timings): - start = self.format_timestamp(this_word["start"]) - end = self.format_timestamp(this_word["end"]) - if last != start: - yield last, start, segment_text - - yield start, end, "".join( - [ - f"<u>{word}</u>" if j == i else word - for j, word in enumerate(all_words) - ] + def iterate_result(self, result: dict, options: dict): + word_timestamps: bool = "words" in result["segments"][0] + highlight_words: bool = options["highlight_words"] + max_line_width: int = options.get("max_line_width", 1000) + max_line_count: Union[int, None] = options["max_line_count"] + + def iterate_wrapped_subtitles(): + line_len = 0 + line_count = 1 + subtitle: list = [] + last_word_start = result["segments"][0]["words"][0].get("start", 0.0) + for segment in result["segments"]: + for timing in segment["words"]: + wrapped_timing = timing.copy() + long_pause = ( + max_line_count is not None + and wrapped_timing.get("start", 0.0) - last_word_start > 3.0 ) - last = end - - if last != segment_end: - yield last, segment_end, segment_text - else: + can_continue = ( + line_len + len(wrapped_timing["word"]) <= max_line_width + ) + if line_len > 0 and can_continue and not long_pause: + # continuation on same subtitle + line_len += len(wrapped_timing["word"]) + subtitle.append(wrapped_timing) + else: + line_break = "" + if ( + len(subtitle) > 0 + and max_line_count is not None + and (long_pause or line_count >= max_line_count) + ): + # subtitle break + yield subtitle + subtitle = [] + line_count = 1 + elif line_len > 0: + # line break + line_break = "\n" + line_count += 1 + 
wrapped_timing["word"] = ( + line_break + wrapped_timing["word"].strip() + ) + line_len = len(wrapped_timing["word"]) + subtitle.append(wrapped_timing) + last_word_start = wrapped_timing.get("start", 0.0) + if max_line_count is None: + yield subtitle + subtitle = [] + line_count = 1 + line_len = 0 + + if max_line_count is not None and len(subtitle) > 0: + yield subtitle + + if word_timestamps and (highlight_words or max_line_width): + for subtitle in iterate_wrapped_subtitles(): + subtitle_start = self.format_timestamp(subtitle[0]["start"]) + subtitle_end = self.format_timestamp(subtitle[-1]["end"]) + subtitle_text = "".join([word["word"] for word in subtitle]) + if highlight_words: + last = subtitle_start + all_words = [timing["word"] for timing in subtitle] + for i, this_word in enumerate(subtitle): + start = self.format_timestamp(this_word["start"]) + end = self.format_timestamp(this_word["end"]) + if last != start: + yield last, start, subtitle_text + + yield start, end, "".join( + [ + re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) + if j == i + else word + for j, word in enumerate(all_words) + ] + ) + last = end + else: + yield subtitle_start, subtitle_end, subtitle_text + else: + for segment in result["segments"]: + segment_start = self.format_timestamp(segment["start"]) + segment_end = self.format_timestamp(segment["end"]) + segment_text = segment["text"].strip().replace("-->", "->") yield segment_start, segment_end, segment_text def format_timestamp(self, seconds: float): @@ -141,9 +200,9 @@ class WriteVTT(SubtitlesWriter): always_include_hours: bool = False decimal_marker: str = "." 
- def write_result(self, result: dict, file: TextIO): + def write_result(self, result: dict, file: TextIO, options: dict): print("WEBVTT\n", file=file) - for start, end, text in self.iterate_result(result): + for start, end, text in self.iterate_result(result, options): print(f"{start} --> {end}\n{text}\n", file=file, flush=True) @@ -152,8 +211,10 @@ class WriteSRT(SubtitlesWriter): always_include_hours: bool = True decimal_marker: str = "," - def write_result(self, result: dict, file: TextIO): - for i, (start, end, text) in enumerate(self.iterate_result(result), start=1): + def write_result(self, result: dict, file: TextIO, options: dict): + for i, (start, end, text) in enumerate( + self.iterate_result(result, options), start=1 + ): print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) @@ -169,7 +230,7 @@ class WriteTSV(ResultWriter): extension: str = "tsv" - def write_result(self, result: dict, file: TextIO): + def write_result(self, result: dict, file: TextIO, options: dict): print("start", "end", "text", sep="\t", file=file) for segment in result["segments"]: print(round(1000 * segment["start"]), file=file, end="\t") @@ -180,11 +241,13 @@ def write_result(self, result: dict, file: TextIO): class WriteJSON(ResultWriter): extension: str = "json" - def write_result(self, result: dict, file: TextIO): + def write_result(self, result: dict, file: TextIO, options: dict): json.dump(result, file) -def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]: +def get_writer( + output_format: str, output_dir: str +) -> Callable[[dict, TextIO, dict], None]: writers = { "txt": WriteTXT, "vtt": WriteVTT, @@ -196,9 +259,9 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], if output_format == "all": all_writers = [writer(output_dir) for writer in writers.values()] - def write_all(result: dict, file: TextIO): + def write_all(result: dict, file: TextIO, options: dict): for writer in all_writers: - 
writer(result, file) + writer(result, file, options) return write_all From 663b2ee54eefb9c6be46a555731d4b96db0a3a52 Mon Sep 17 00:00:00 2001 From: Ryan Heise Date: Sun, 2 Apr 2023 21:03:18 +1000 Subject: [PATCH 2/2] Refactor subtitle generator --- whisper/utils.py | 64 ++++++++++++++++++++---------------------------- 1 file changed, 27 insertions(+), 37 deletions(-) diff --git a/whisper/utils.py b/whisper/utils.py index d7d27f8c2..ba5a10c41 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -3,7 +3,7 @@ import re import sys import zlib -from typing import Callable, TextIO, Union +from typing import Callable, Optional, TextIO system_encoding = sys.getdefaultencoding() @@ -101,36 +101,35 @@ class SubtitlesWriter(ResultWriter): decimal_marker: str def iterate_result(self, result: dict, options: dict): - word_timestamps: bool = "words" in result["segments"][0] + raw_max_line_width: Optional[int] = options["max_line_width"] + max_line_count: Optional[int] = options["max_line_count"] highlight_words: bool = options["highlight_words"] - max_line_width: int = options.get("max_line_width", 1000) - max_line_count: Union[int, None] = options["max_line_count"] + max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width + preserve_segments = max_line_count is None or raw_max_line_width is None - def iterate_wrapped_subtitles(): + def iterate_subtitles(): line_len = 0 line_count = 1 - subtitle: list = [] - last_word_start = result["segments"][0]["words"][0].get("start", 0.0) + # the next subtitle to yield (a list of word timings with whitespace) + subtitle: list[dict] = [] + last = result["segments"][0]["words"][0]["start"] for segment in result["segments"]: - for timing in segment["words"]: - wrapped_timing = timing.copy() - long_pause = ( - max_line_count is not None - and wrapped_timing.get("start", 0.0) - last_word_start > 3.0 - ) - can_continue = ( - line_len + len(wrapped_timing["word"]) <= max_line_width - ) - if line_len > 0 and can_continue and 
not long_pause: - # continuation on same subtitle - line_len += len(wrapped_timing["word"]) - subtitle.append(wrapped_timing) + for i, original_timing in enumerate(segment["words"]): + timing = original_timing.copy() + long_pause = not preserve_segments and timing["start"] - last > 3.0 + has_room = line_len + len(timing["word"]) <= max_line_width + seg_break = i == 0 and len(subtitle) > 0 and preserve_segments + if line_len > 0 and has_room and not long_pause and not seg_break: + # line continuation + line_len += len(timing["word"]) else: - line_break = "" + # new line + timing["word"] = timing["word"].strip() if ( len(subtitle) > 0 and max_line_count is not None and (long_pause or line_count >= max_line_count) + or seg_break ): # subtitle break yield subtitle @@ -138,25 +137,16 @@ def iterate_wrapped_subtitles(): line_count = 1 elif line_len > 0: # line break - line_break = "\n" line_count += 1 - wrapped_timing["word"] = ( - line_break + wrapped_timing["word"].strip() - ) - line_len = len(wrapped_timing["word"]) - subtitle.append(wrapped_timing) - last_word_start = wrapped_timing.get("start", 0.0) - if max_line_count is None: - yield subtitle - subtitle = [] - line_count = 1 - line_len = 0 - - if max_line_count is not None and len(subtitle) > 0: + timing["word"] = "\n" + timing["word"] + line_len = len(timing["word"].strip()) + subtitle.append(timing) + last = timing["start"] + if len(subtitle) > 0: yield subtitle - if word_timestamps and (highlight_words or max_line_width): - for subtitle in iterate_wrapped_subtitles(): + if "words" in result["segments"][0]: + for subtitle in iterate_subtitles(): subtitle_start = self.format_timestamp(subtitle[0]["start"]) subtitle_end = self.format_timestamp(subtitle[-1]["end"]) subtitle_text = "".join([word["word"] for word in subtitle])