From c6571d769eb0522bde2c22f2c90a6dd0ee6db791 Mon Sep 17 00:00:00 2001 From: Aaryan YVS <51906812+Aaryan369@users.noreply.github.com> Date: Sun, 16 Oct 2022 17:11:18 +0530 Subject: [PATCH 1/4] Added --output option --output option will help select the output files that will be generated. Corrected the logic, which wrongly shows progress bar when verbose is set to False --- whisper/transcribe.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 654f7b419..0174bcc84 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -171,7 +171,7 @@ def add_segment( num_frames = mel.shape[-1] previous_seek_value = seek - with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: + with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is False) as pbar: while seek < num_frames: timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype) @@ -256,7 +256,8 @@ def cli(): parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") - + + parser.add_argument("--output", type=str, default="all", choices=["none", "txt", "vtt", "srt", "all"], help="output files to generate, all(generates txt, vtt and srt), txt(generates only txt), vtt(generates txt and vtt), srt(generates txt and srt) ") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") @@ -277,6 +278,7 @@ def cli(): parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") + args = parser.parse_args().__dict__ model_name: str = args.pop("model") model_dir: str = args.pop("model_dir") @@ -307,18 +309,21 @@ def cli(): result = transcribe(model, audio_path, temperature=temperature, **args) audio_basename = os.path.basename(audio_path) - + # save TXT - with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: - write_txt(result["segments"], file=txt) + if args["output"] != "none": + with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: + write_txt(result["segments"], file=txt) # save VTT - with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt: - write_vtt(result["segments"], file=vtt) + if args["output"] in ["vtt","all"]: + with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt: + write_vtt(result["segments"], file=vtt) # save SRT - with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: - write_srt(result["segments"], file=srt) + if args["output"] in ["srt","all"]: + with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: + write_srt(result["segments"], file=srt) if __name__ == '__main__': From fe1b4d26a992e2c86f963a5ede60310ad5dd9526 Mon Sep 17 00:00:00 2001 From: Aaryan YVS <51906812+Aaryan369@users.noreply.github.com> Date: Sun, 16 Oct 2022 17:20:11 +0530 Subject: [PATCH 2/4] Changed output_files variable --- whisper/transcribe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 0174bcc84..41c6cc7a5 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -284,6 +284,7 @@ def cli(): model_dir: str = args.pop("model_dir") output_dir: str = args.pop("output_dir") device: str = args.pop("device") + output_files: str = args.pop("output") os.makedirs(output_dir, exist_ok=True) if model_name.endswith(".en") and args["language"] not in {"en", "English"}: @@ -311,17 +312,17 @@ def cli(): audio_basename = os.path.basename(audio_path) # save TXT - if args["output"] != "none": + if output_files != "none": with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: write_txt(result["segments"], file=txt) # save VTT - if args["output"] in ["vtt","all"]: + if output_files in ["vtt","all"]: with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt: write_vtt(result["segments"], file=vtt) # save SRT - if args["output"] in ["srt","all"]: + if output_files in ["srt","all"]: with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: write_srt(result["segments"], file=srt) From e27ee28a254ab189a88728e47a2b80ac17b8d514 Mon Sep 17 00:00:00 2001 From: Aaryan YVS <51906812+Aaryan369@users.noreply.github.com> Date: Tue, 18 Oct 2022 16:12:51 +0530 Subject: [PATCH 3/4] Changed back the tqdm verbose --- whisper/transcribe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 41c6cc7a5..8f406b50c 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -171,7 +171,7 @@ def add_segment( num_frames = mel.shape[-1] previous_seek_value = seek - with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is False) as pbar: + with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: while seek < num_frames: timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype) From dfea59cba7b0406a05fbc15b2e5f80e7d8cd2f80 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Sat, 21 Jan 2023 23:53:47 -0800 Subject: [PATCH 4/4] refactor output format handling --- whisper/transcribe.py | 29 +++------- whisper/utils.py | 125 ++++++++++++++++++++++++++++-------------- 2 files changed, 90 insertions(+), 64 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index dc44a8d80..02952dfe2 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -11,7 +11,7 @@ from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram from .decoding import DecodingOptions, DecodingResult from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer -from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt +from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, get_writer if TYPE_CHECKING: from .model import Whisper @@ -260,9 +260,9 @@ def cli(): parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") + parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") - - parser.add_argument("--output", type=str, default="all", choices=["none", "txt", "vtt", "srt", "all"], help="output files to generate, all(generates txt, vtt and srt), txt(generates only txt), vtt(generates txt and vtt), srt(generates txt and srt) ") + parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") @@ -283,13 +283,12 @@ def cli(): parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") - args = parser.parse_args().__dict__ model_name: str = args.pop("model") model_dir: str = args.pop("model_dir") output_dir: str = args.pop("output_dir") + output_format: str = args.pop("output_format") device: str = args.pop("device") - output_files: str = args.pop("output") os.makedirs(output_dir, exist_ok=True) if model_name.endswith(".en") and args["language"] not in {"en", "English"}: @@ -311,25 +310,11 @@ def cli(): from . import load_model model = load_model(model_name, device=device, download_root=model_dir) + writer = get_writer(output_format, output_dir) + for audio_path in args.pop("audio"): result = transcribe(model, audio_path, temperature=temperature, **args) - - audio_basename = os.path.basename(audio_path) - - # save TXT - if output_files != "none": - with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: - write_txt(result["segments"], file=txt) - - # save VTT - if output_files in ["vtt","all"]: - with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt: - write_vtt(result["segments"], file=vtt) - - # save SRT - if output_files in ["srt","all"]: - with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: - write_srt(result["segments"], file=srt) + writer(result, audio_path) if __name__ == '__main__': diff --git a/whisper/utils.py b/whisper/utils.py index 233d3d4ff..8315e7f45 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -1,5 +1,7 @@ +import json +import os import zlib -from typing import Iterator, TextIO +from typing import Callable, TextIO def exact_div(x, y): @@ -45,44 +47,83 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" -def write_txt(transcript: Iterator[dict], file: TextIO): - for segment in transcript: - print(segment['text'].strip(), file=file, flush=True) - - -def write_vtt(transcript: Iterator[dict], file: TextIO): - print("WEBVTT\n", file=file) - for segment in transcript: - print( - f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" - f"{segment['text'].strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) - - -def write_srt(transcript: Iterator[dict], file: TextIO): - """ - Write a transcript to a file in SRT format. - - Example usage: - from pathlib import Path - from whisper.utils import write_srt - - result = transcribe(model, audio_path, temperature=temperature, **args) - - # save SRT - audio_basename = Path(audio_path).stem - with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: - write_srt(result["segments"], file=srt) - """ - for i, segment in enumerate(transcript, start=1): - # write srt lines - print( - f"{i}\n" - f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " - f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" - f"{segment['text'].strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) +class ResultWriter: + extension: str + + def __init__(self, output_dir: str): + self.output_dir = output_dir + + def __call__(self, result: dict, audio_path: str): + audio_basename = os.path.basename(audio_path) + output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension) + + with open(output_path, "w", encoding="utf-8") as f: + self.write_result(result, file=f) + + def write_result(self, result: dict, file: TextIO): + raise NotImplementedError + + +class WriteTXT(ResultWriter): + extension: str = "txt" + + def write_result(self, result: dict, file: TextIO): + for segment in result["segments"]: + print(segment['text'].strip(), file=file, flush=True) + + +class WriteVTT(ResultWriter): + extension: str = "vtt" + + def write_result(self, result: dict, file: TextIO): + print("WEBVTT\n", file=file) + for segment in result["segments"]: + print( + f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", + file=file, + flush=True, + ) + + +class WriteSRT(ResultWriter): + extension: str = "srt" + + def write_result(self, result: dict, file: TextIO): + for i, segment in enumerate(result["segments"], start=1): + # write srt lines + print( + f"{i}\n" + f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " + f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", + file=file, + flush=True, + ) + + +class WriteJSON(ResultWriter): + extension: str = "json" + + def write_result(self, result: dict, file: TextIO): + json.dump(result, file) + + +def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]: + writers = { + "txt": WriteTXT, + "vtt": WriteVTT, + "srt": WriteSRT, + "json": WriteJSON, + } + + if output_format == "all": + all_writers = [writer(output_dir) for writer in writers.values()] + + def write_all(result: dict, file: TextIO): + for writer in all_writers: + writer(result, file) + + return write_all + + return writers[output_format](output_dir)