Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement max line width and max line count, and make word highlighting optional #1184

Merged
merged 4 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,9 @@ def cli():
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on

Expand Down Expand Up @@ -433,9 +436,17 @@ def cli():
model = load_model(model_name, device=device, download_root=model_dir)

writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"]
if not args["word_timestamps"]:
for option in word_options:
if args[option]:
parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path)
writer(result, audio_path, writer_args)


if __name__ == "__main__":
Expand Down
135 changes: 94 additions & 41 deletions whisper/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import os
import re
import sys
import zlib
from typing import Callable, TextIO
from typing import Callable, Optional, TextIO

system_encoding = sys.getdefaultencoding()

Expand Down Expand Up @@ -73,24 +74,24 @@ class ResultWriter:
def __init__(self, output_dir: str):
self.output_dir = output_dir

def __call__(self, result: dict, audio_path: str):
def __call__(self, result: dict, audio_path: str, options: dict):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
)

with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)
self.write_result(result, file=f, options=options)

def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
raise NotImplementedError


class WriteTXT(ResultWriter):
extension: str = "txt"

def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)

Expand All @@ -99,33 +100,81 @@ class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str

def iterate_result(self, result: dict):
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")

if word_timings := segment.get("words", None):
all_words = [timing["word"] for timing in word_timings]
all_words[0] = all_words[0].strip() # remove the leading space, if any
last = segment_start
for i, this_word in enumerate(word_timings):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, segment_text

yield start, end, "".join(
[
f"<u>{word}</u>" if j == i else word
for j, word in enumerate(all_words)
]
)
last = end

if last != segment_end:
yield last, segment_end, segment_text
else:
def iterate_result(self, result: dict, options: dict):
raw_max_line_width: Optional[int] = options["max_line_width"]
max_line_count: Optional[int] = options["max_line_count"]
highlight_words: bool = options["highlight_words"]
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
preserve_segments = max_line_count is None or raw_max_line_width is None

def iterate_subtitles():
line_len = 0
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
last = result["segments"][0]["words"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):
timing = original_timing.copy()
long_pause = not preserve_segments and timing["start"] - last > 3.0
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if line_len > 0 and has_room and not long_pause and not seg_break:
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
if len(subtitle) > 0:
yield subtitle

if "words" in result["segments"][0]:
for subtitle in iterate_subtitles():
subtitle_start = self.format_timestamp(subtitle[0]["start"])
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
subtitle_text = "".join([word["word"] for word in subtitle])
if highlight_words:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, subtitle_text

yield start, end, "".join(
[
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
for j, word in enumerate(all_words)
]
)
last = end
else:
yield subtitle_start, subtitle_end, subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
yield segment_start, segment_end, segment_text

def format_timestamp(self, seconds: float):
Expand All @@ -141,9 +190,9 @@ class WriteVTT(SubtitlesWriter):
always_include_hours: bool = False
decimal_marker: str = "."

def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result):
for start, end, text in self.iterate_result(result, options):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)


Expand All @@ -152,8 +201,10 @@ class WriteSRT(SubtitlesWriter):
always_include_hours: bool = True
decimal_marker: str = ","

def write_result(self, result: dict, file: TextIO):
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
def write_result(self, result: dict, file: TextIO, options: dict):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options), start=1
):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)


Expand All @@ -169,7 +220,7 @@ class WriteTSV(ResultWriter):

extension: str = "tsv"

def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
Expand All @@ -180,11 +231,13 @@ def write_result(self, result: dict, file: TextIO):
class WriteJSON(ResultWriter):
extension: str = "json"

def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
json.dump(result, file)


def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
Expand All @@ -196,9 +249,9 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]

def write_all(result: dict, file: TextIO):
def write_all(result: dict, file: TextIO, options: dict):
for writer in all_writers:
writer(result, file)
writer(result, file, options)

return write_all

Expand Down