diff --git a/whisper/README.md b/whisper/README.md
index ac6e95f66..cd3bc684a 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -25,7 +25,7 @@ pip install mlx-whisper
 
 At its simplest:
 
-```
+```sh
 mlx_whisper audio_file.mp3
 ```
 
@@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
 are many other supported command line options. To see them all, run
 `mlx_whisper -h`.
 
+You can also pipe the audio content of other programs via stdin:
+
+```sh
+some-process | mlx_whisper -
+```
+
+The default output file name will be `content.*`. You can specify the name with
+the `--output-name` flag.
+
 #### API
 
 Transcribe audio with:
@@ -103,7 +112,7 @@ python convert.py --help
 ```
 
 By default, the conversion script will make the directory `mlx_models`
-and save the converted `weights.npz` and `config.json` there. 
+and save the converted `weights.npz` and `config.json` there.
 
 Each time it is run, `convert.py` will overwrite any model in the provided
 path. To save different models, make sure to set `--mlx-path` to a unique
diff --git a/whisper/mlx_whisper/audio.py b/whisper/mlx_whisper/audio.py
index e04309c10..c8cca07c6 100644
--- a/whisper/mlx_whisper/audio.py
+++ b/whisper/mlx_whisper/audio.py
@@ -3,7 +3,7 @@
 import os
 from functools import lru_cache
 from subprocess import CalledProcessError, run
-from typing import Union
+from typing import Optional, Union
 
 import mlx.core as mx
 import numpy as np
@@ -21,7 +21,9 @@
 TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN  # 20ms per audio token
 
 
-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(
+    file: Optional[str] = None, sr: int = SAMPLE_RATE, from_stdin: bool = False
+):
     """
     Open an audio file and read as mono waveform, resampling as necessary
 
@@ -39,19 +41,25 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
 
     # This launches a subprocess to decode audio while down-mixing
-    # and resampling as necessary.  Requires the ffmpeg CLI in PATH.
+    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+    if from_stdin:
+        # Decode the stream piped into this process's stdin (ffmpeg's
+        # "pipe:0" input); `file` is ignored in this mode.
+        cmd = ["ffmpeg", "-i", "pipe:0"]
+    elif file is None:
+        raise ValueError("`file` is required unless from_stdin=True")
+    else:
+        cmd = ["ffmpeg", "-nostdin", "-i", file]
+
     # fmt: off
-    cmd = [
-        "ffmpeg",
-        "-nostdin",
+    cmd.extend([
         "-threads", "0",
-        "-i", file,
         "-f", "s16le",
         "-ac", "1",
         "-acodec", "pcm_s16le",
         "-ar", str(sr),
         "-"
-    ]
+    ])
     # fmt: on
     try:
         out = run(cmd, capture_output=True, check=True).stdout
diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py
index c28133386..7d08a0432 100644
--- a/whisper/mlx_whisper/cli.py
+++ b/whisper/mlx_whisper/cli.py
@@ -2,9 +2,11 @@
 
 import argparse
 import os
+import pathlib
 import traceback
 import warnings
 
+from . import audio
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
 from .transcribe import transcribe
 from .writers import get_writer
@@ -27,15 +29,24 @@ def str2bool(string):
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
-    parser.add_argument(
-        "audio", nargs="+", type=str, help="Audio file(s) to transcribe"
-    )
+
+    parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
+
     parser.add_argument(
         "--model",
         default="mlx-community/whisper-tiny",
         type=str,
         help="The model directory or hugging face repo",
     )
+    parser.add_argument(
+        "--output-name",
+        type=str,
+        default=None,
+        help=(
+            "The name of transcription/translation output files before "
+            "--output-format extensions"
+        ),
+    )
     parser.add_argument(
         "--output-dir",
         "-o",
@@ -200,6 +211,7 @@ def main():
     path_or_hf_repo: str = args.pop("model")
     output_dir: str = args.pop("output_dir")
     output_format: str = args.pop("output_format")
+    output_name: str = args.pop("output_name")
     os.makedirs(output_dir, exist_ok=True)
 
     writer = get_writer(output_format, output_dir)
@@ -219,17 +231,26 @@ def main():
         warnings.warn("--max-line-count has no effect without --max-line-width")
     if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
         warnings.warn("--max-words-per-line has no effect with --max-line-width")
-    for audio_path in args.pop("audio"):
+
+    for audio_obj in args.pop("audio"):
+        # Resolve the label and output name per input: assigning back to
+        # `output_name` would make every later file reuse the first file's name.
+        from_stdin = audio_obj == "-"
+        audio_label = "stdin" if from_stdin else audio_obj
+        default_name = "content" if from_stdin else pathlib.Path(audio_obj).stem
         try:
+            if from_stdin:
+                # receive the contents from stdin rather than read a file
+                audio_obj = audio.load_audio(from_stdin=True)
             result = transcribe(
-                audio_path,
+                audio_obj,
                 path_or_hf_repo=path_or_hf_repo,
                 **args,
             )
-            writer(result, audio_path, **writer_args)
+            writer(result, output_name or default_name, **writer_args)
         except Exception as e:
             traceback.print_exc()
-            print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
+            print(f"Skipping {audio_label} due to {type(e).__name__}: {str(e)}")
 
 
 if __name__ == "__main__":
diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py
index 464ead189..cdb35063c 100644
--- a/whisper/mlx_whisper/writers.py
+++ b/whisper/mlx_whisper/writers.py
@@ -1,10 +1,8 @@
 # Copyright © 2024 Apple Inc.
 
 import json
-import os
+import pathlib
 import re
-import sys
-import zlib
 from typing import Callable, List, Optional, TextIO
 
 
@@ -43,15 +41,13 @@ def __init__(self, output_dir: str):
         self.output_dir = output_dir
 
     def __call__(
-        self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
+        self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
     ):
-        audio_basename = os.path.basename(audio_path)
-        audio_basename = os.path.splitext(audio_basename)[0]
-        output_path = os.path.join(
-            self.output_dir, audio_basename + "." + self.extension
-        )
+        # Append the extension rather than calling with_suffix(), which would
+        # replace everything after the last dot in names like "talk.part1".
+        output_path = pathlib.Path(self.output_dir) / f"{output_name}.{self.extension}"
 
-        with open(output_path, "w", encoding="utf-8") as f:
+        with output_path.open("wt", encoding="utf-8") as f:
             self.write_result(result, file=f, options=options, **kwargs)
 
     def write_result(