Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added --output option #333

Merged
merged 5 commits into from
Jan 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, get_writer

if TYPE_CHECKING:
from .model import Whisper
Expand Down Expand Up @@ -260,6 +260,7 @@ def cli():
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
Expand All @@ -286,6 +287,7 @@ def cli():
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)

Expand All @@ -308,22 +310,11 @@ def cli():
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)

writer = get_writer(output_format, output_dir)

for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)

audio_basename = os.path.basename(audio_path)

# save TXT
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
write_txt(result["segments"], file=txt)

# save VTT
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
write_vtt(result["segments"], file=vtt)

# save SRT
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
writer(result, audio_path)


if __name__ == '__main__':
Expand Down
125 changes: 83 additions & 42 deletions whisper/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
import zlib
from typing import Iterator, TextIO
from typing import Callable, TextIO


def exact_div(x, y):
Expand Down Expand Up @@ -45,44 +47,83 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"


def write_txt(transcript: Iterator[dict], file: TextIO):
for segment in transcript:
print(segment['text'].strip(), file=file, flush=True)


def write_vtt(transcript: Iterator[dict], file: TextIO):
print("WEBVTT\n", file=file)
for segment in transcript:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)


def write_srt(transcript: Iterator[dict], file: TextIO):
"""
Write a transcript to a file in SRT format.

Example usage:
from pathlib import Path
from whisper.utils import write_srt

result = transcribe(model, audio_path, temperature=temperature, **args)

# save SRT
audio_basename = Path(audio_path).stem
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
class ResultWriter:
extension: str

def __init__(self, output_dir: str):
self.output_dir = output_dir

def __call__(self, result: dict, audio_path: str):
audio_basename = os.path.basename(audio_path)
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)

with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)

def write_result(self, result: dict, file: TextIO):
raise NotImplementedError


class WriteTXT(ResultWriter):
extension: str = "txt"

def write_result(self, result: dict, file: TextIO):
for segment in result["segments"]:
print(segment['text'].strip(), file=file, flush=True)


class WriteVTT(ResultWriter):
extension: str = "vtt"

def write_result(self, result: dict, file: TextIO):
print("WEBVTT\n", file=file)
for segment in result["segments"]:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)


class WriteSRT(ResultWriter):
extension: str = "srt"

def write_result(self, result: dict, file: TextIO):
for i, segment in enumerate(result["segments"], start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)


class WriteJSON(ResultWriter):
extension: str = "json"

def write_result(self, result: dict, file: TextIO):
json.dump(result, file)


def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"json": WriteJSON,
}

if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]

def write_all(result: dict, file: TextIO):
for writer in all_writers:
writer(result, file)

return write_all

return writers[output_format](output_dir)