diff --git a/Dockerfile b/Dockerfile index e769468..138e1e0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,4 +35,4 @@ COPY ./.env /app/.env WORKDIR /app -CMD ["uvicorn", "--reload", "--host=0.0.0.0", "--port=5001", "wordcab_transcribe.main:app"] +CMD ["uvicorn", "--host=0.0.0.0", "--port=5001", "wordcab_transcribe.main:app"] diff --git a/notebooks/async_inference.py b/notebooks/async_inference.py index cf78691..66d69fa 100644 --- a/notebooks/async_inference.py +++ b/notebooks/async_inference.py @@ -13,7 +13,6 @@ "diarization": False, # Longer processing time but speaker segment attribution "source_lang": "en", # optional, default is "en" "timestamps": "s", # optional, default is "s". Can be "s", "ms" or "hms". - "use_batch": False, # optional, default is False "internal_vad": False, # optional, default is False "word_timestamps": True, # optional, default is False } diff --git a/notebooks/youtube_inference.py b/notebooks/youtube_inference.py index 3f332fb..fd5e571 100644 --- a/notebooks/youtube_inference.py +++ b/notebooks/youtube_inference.py @@ -7,11 +7,11 @@ # params = {"url": "https://youtu.be/vAvcxeXtBz0"} # params = {"url": "https://youtu.be/pmjrj_TrOEI"} # params = {"url": "https://youtu.be/SVwLEocqK0E"} # 2h - 3 speakers -# params = {"url": "https://youtu.be/ry9SYnV3svc"} # eng sample - 2 speakers +params = {"url": "https://youtu.be/ry9SYnV3svc"} # eng sample - 2 speakers # params = {"url": "https://youtu.be/oAhVu3HvWnw"} # params = {"url": "https://youtu.be/sfQMxf9Dm8I"} # params = {"url": "https://youtu.be/uLBZf9eS4Y0"} -params = {"url": "https://youtu.be/JJbtS8CMr80"} # 4h - multiple speakers +# params = {"url": "https://youtu.be/JJbtS8CMr80"} # 4h - multiple speakers data = { "alignment": False, # Longer processing time but better timestamps @@ -19,7 +19,6 @@ "diarization": True, # Longer processing time but speaker segment attribution "source_lang": "nl", # optional, default is "en" "timestamps": "s", # optional, default is "s". Can be "s", "ms" or "hms". 
- "use_batch": False, # optional, default is False "internal_vad": False, # optional, default is False "word_timestamps": False, # optional, default is False } diff --git a/noxfile.py b/noxfile.py index 44f05aa..aae2e74 100644 --- a/noxfile.py +++ b/noxfile.py @@ -149,6 +149,7 @@ def tests(session: Session) -> None: "--parallel", "-m", "pytest", + "tests/", *session.posargs, ) finally: diff --git a/tests/test_models.py b/tests/test_models.py index 2381c99..a018784 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -30,6 +30,7 @@ CortexPayload, CortexUrlResponse, CortexYoutubeResponse, + ProcessTimes, Timestamps, Utterance, Word, @@ -37,6 +38,22 @@ ) +def test_process_times() -> None: + """Test the ProcessTimes model.""" + times = ProcessTimes( + total=10.0, + transcription=5.0, + diarization=None, + alignment=None, + post_processing=2.0, + ) + assert times.total == 10.0 + assert times.transcription == 5.0 + assert times.diarization is None + assert times.alignment is None + assert times.post_processing == 2.0 + + def test_timestamps() -> None: """Test the Timestamps enum.""" assert Timestamps.seconds == "s" @@ -142,11 +159,14 @@ def test_audio_request() -> None: assert request.dual_channel is True assert request.source_lang == "en" assert request.timestamps == "s" - assert request.use_batch is False assert request.vocab == [] assert request.word_timestamps is False assert request.internal_vad is False assert request.repetition_penalty == 1.2 + assert request.compression_ratio_threshold == 2.4 + assert request.log_prob_threshold == -1.0 + assert request.no_speech_threshold == 0.6 + assert request.condition_on_previous_text is True def test_audio_response() -> None: @@ -160,11 +180,19 @@ def test_audio_response() -> None: dual_channel=False, source_lang="en", timestamps="s", - use_batch=False, vocab=["custom company", "custom product"], word_timestamps=False, internal_vad=False, repetition_penalty=1.2, + compression_ratio_threshold=1.8, + log_prob_threshold=-1.0, + no_speech_threshold=0.4, + condition_on_previous_text=False, + process_times=ProcessTimes( + total=10.0, + transcription=5.0, + post_processing=2.0, + ), ) assert response.utterances == [] assert response.audio_duration == 0.0 @@ -174,11 +202,17 @@ def test_audio_response() -> None: assert response.dual_channel is False assert response.source_lang == "en" assert response.timestamps == "s" - assert response.use_batch is False assert response.vocab == ["custom company", "custom product"] assert response.word_timestamps is False assert response.internal_vad is False assert response.repetition_penalty == 1.2 + assert response.compression_ratio_threshold == 1.8 + assert response.log_prob_threshold == -1.0 + assert response.no_speech_threshold == 0.4 + assert response.condition_on_previous_text is False + assert response.process_times == ProcessTimes( + total=10.0, transcription=5.0, post_processing=2.0 + ) response = AudioResponse( utterances=[ @@ -204,11 +238,19 @@ def test_audio_response() -> None: dual_channel=True, source_lang="en", timestamps="s", - use_batch=False, vocab=["custom company", "custom product"], word_timestamps=True, internal_vad=False, repetition_penalty=1.2, + compression_ratio_threshold=1.8, + log_prob_threshold=-1.0, + no_speech_threshold=0.4, + condition_on_previous_text=False, + process_times=ProcessTimes( + total=10.0, + transcription=5.0, + post_processing=2.0, + ), ) assert response.utterances == [ Utterance( @@ -233,11 +275,17 @@ def test_audio_response() -> None: assert response.dual_channel 
is True assert response.source_lang == "en" assert response.timestamps == "s" - assert response.use_batch is False assert response.vocab == ["custom company", "custom product"] assert response.word_timestamps is True assert response.internal_vad is False assert response.repetition_penalty == 1.2 + assert response.compression_ratio_threshold == 1.8 + assert response.log_prob_threshold == -1.0 + assert response.no_speech_threshold == 0.4 + assert response.condition_on_previous_text is False + assert response.process_times == ProcessTimes( + total=10.0, transcription=5.0, post_processing=2.0 + ) def test_base_request_valid() -> None: @@ -261,10 +309,13 @@ def test_base_request_default() -> None: assert req.diarization is False assert req.source_lang == "en" assert req.timestamps == "s" - assert req.use_batch is False assert req.word_timestamps is False assert req.internal_vad is False assert req.repetition_penalty == 1.2 + assert req.compression_ratio_threshold == 2.4 + assert req.log_prob_threshold == -1.0 + assert req.no_speech_threshold == 0.6 + assert req.condition_on_previous_text is True def test_base_request_invalid() -> None: @@ -298,11 +349,21 @@ def test_base_response() -> None: diarization=False, source_lang="en", timestamps="s", - use_batch=False, vocab=["custom company", "custom product"], word_timestamps=False, internal_vad=False, repetition_penalty=1.2, + compression_ratio_threshold=1.8, + log_prob_threshold=-1.0, + no_speech_threshold=0.4, + condition_on_previous_text=False, + process_times=ProcessTimes( + total=10.0, + transcription=5.0, + diarization=2.0, + alignment=2.0, + post_processing=1.0, + ), ) assert response.utterances == [ Utterance( @@ -326,11 +387,21 @@ def test_base_response() -> None: assert response.diarization is False assert response.source_lang == "en" assert response.timestamps == "s" - assert response.use_batch is False assert response.vocab == ["custom company", "custom product"] assert response.word_timestamps is False assert response.internal_vad is False assert response.repetition_penalty == 1.2 + assert response.compression_ratio_threshold == 1.8 + assert response.log_prob_threshold == -1.0 + assert response.no_speech_threshold == 0.4 + assert response.condition_on_previous_text is False + assert response.process_times == ProcessTimes( + total=10.0, + transcription=5.0, + diarization=2.0, + alignment=2.0, + post_processing=1.0, + ) def test_cortex_error() -> None: @@ -353,7 +424,6 @@ def test_cortex_payload() -> None: dual_channel=False, source_lang="en", timestamps="s", - use_batch=False, word_timestamps=False, internal_vad=False, repetition_penalty=1.2, @@ -369,11 +439,14 @@ def test_cortex_payload() -> None: assert payload.dual_channel is False assert payload.source_lang == "en" assert payload.timestamps == "s" - assert payload.use_batch is False assert payload.vocab == [] assert payload.word_timestamps is False assert payload.internal_vad is False assert payload.repetition_penalty == 1.2 + assert payload.compression_ratio_threshold == 2.4 + assert payload.log_prob_threshold == -1.0 + assert payload.no_speech_threshold == 0.6 + assert payload.condition_on_previous_text is True assert payload.job_name == "test_job" assert payload.ping is False @@ -403,11 +476,21 @@ def test_cortex_url_response() -> None: diarization=False, source_lang="en", timestamps="s", - use_batch=False, vocab=["custom company", "custom product"], word_timestamps=False, internal_vad=False, repetition_penalty=1.2, + compression_ratio_threshold=1.8, + log_prob_threshold=-1.0, + 
no_speech_threshold=0.4, + condition_on_previous_text=False, + process_times=ProcessTimes( + total=10.0, + transcription=5.0, + diarization=2.0, + alignment=2.0, + post_processing=1.0, + ), dual_channel=False, job_name="test_job", request_id="test_request_id", @@ -434,11 +517,21 @@ def test_cortex_url_response() -> None: assert response.diarization is False assert response.source_lang == "en" assert response.timestamps == "s" - assert response.use_batch is False assert response.vocab == ["custom company", "custom product"] assert response.word_timestamps is False assert response.internal_vad is False assert response.repetition_penalty == 1.2 + assert response.compression_ratio_threshold == 1.8 + assert response.log_prob_threshold == -1.0 + assert response.no_speech_threshold == 0.4 + assert response.condition_on_previous_text is False + assert response.process_times == ProcessTimes( + total=10.0, + transcription=5.0, + diarization=2.0, + alignment=2.0, + post_processing=1.0, + ) assert response.dual_channel is False assert response.job_name == "test_job" assert response.request_id == "test_request_id" @@ -469,11 +562,21 @@ def test_cortex_youtube_response() -> None: diarization=False, source_lang="en", timestamps="s", - use_batch=False, vocab=["custom company", "custom product"], word_timestamps=False, internal_vad=False, repetition_penalty=1.2, + compression_ratio_threshold=1.8, + log_prob_threshold=-1.0, + no_speech_threshold=0.4, + condition_on_previous_text=False, + process_times=ProcessTimes( + total=10.0, + transcription=5.0, + diarization=2.0, + alignment=2.0, + post_processing=1.0, + ), video_url="https://www.youtube.com/watch?v=dQw4w9WgXcQ", job_name="test_job", request_id="test_request_id", @@ -500,11 +603,21 @@ def test_cortex_youtube_response() -> None: assert response.diarization is False assert response.source_lang == "en" assert response.timestamps == "s" - assert response.use_batch is False assert response.vocab == ["custom company", "custom product"] assert response.word_timestamps is False assert response.internal_vad is False assert response.repetition_penalty == 1.2 + assert response.compression_ratio_threshold == 1.8 + assert response.log_prob_threshold == -1.0 + assert response.no_speech_threshold == 0.4 + assert response.condition_on_previous_text is False + assert response.process_times == ProcessTimes( + total=10.0, + transcription=5.0, + diarization=2.0, + alignment=2.0, + post_processing=1.0, + ) assert response.video_url == "https://www.youtube.com/watch?v=dQw4w9WgXcQ" assert response.job_name == "test_job" assert response.request_id == "test_request_id" @@ -535,11 +648,21 @@ def test_youtube_response() -> None: diarization=False, source_lang="en", timestamps="s", - use_batch=False, vocab=["custom company", "custom product"], word_timestamps=False, internal_vad=False, repetition_penalty=1.2, + compression_ratio_threshold=1.8, + log_prob_threshold=-1.0, + no_speech_threshold=0.4, + condition_on_previous_text=False, + process_times=ProcessTimes( + total=10.0, + transcription=5.0, + diarization=2.0, + alignment=2.0, + post_processing=1.0, + ), video_url="https://www.youtube.com/watch?v=dQw4w9WgXcQ", ) assert response.utterances == [ @@ -564,9 +687,19 @@ def test_youtube_response() -> None: assert response.diarization is False assert response.source_lang == "en" assert response.timestamps == "s" - assert response.use_batch is False assert response.vocab == ["custom company", "custom product"] assert response.word_timestamps is False assert response.internal_vad is 
False assert response.repetition_penalty == 1.2 + assert response.compression_ratio_threshold == 1.8 + assert response.log_prob_threshold == -1.0 + assert response.no_speech_threshold == 0.4 + assert response.condition_on_previous_text is False + assert response.process_times == ProcessTimes( + total=10.0, + transcription=5.0, + diarization=2.0, + alignment=2.0, + post_processing=1.0, + ) assert response.video_url == "https://www.youtube.com/watch?v=dQw4w9WgXcQ" diff --git a/wordcab_transcribe/config.py b/wordcab_transcribe/config.py index 6420797..5ec4076 100644 --- a/wordcab_transcribe/config.py +++ b/wordcab_transcribe/config.py @@ -19,16 +19,22 @@ # and limitations under the License. """Configuration module of the Wordcab Transcribe.""" +import asyncio +from contextlib import asynccontextmanager from os import getenv from pathlib import Path from typing import Dict, List from dotenv import load_dotenv +from fastapi import FastAPI from faster_whisper.utils import _MODELS from loguru import logger from pydantic import field_validator from pydantic.dataclasses import dataclass +from wordcab_transcribe.services.asr_service import ASRAsyncService, ASRLiveService +from wordcab_transcribe.utils import download_model, retrieve_user_platform + @dataclass class Settings: @@ -283,3 +289,55 @@ def __post_init__(self): svix_api_key=getenv("SVIX_API_KEY", ""), svix_app_id=getenv("SVIX_APP_ID", ""), ) + +# Define the maximum number of files to pre-download for the async ASR service +download_limit = asyncio.Semaphore(10) + +# Define the ASR service to use depending on the settings +if settings.asr_type == "live": + asr = ASRLiveService() +elif settings.asr_type == "async": + asr = ASRAsyncService( + whisper_model=settings.whisper_model, + compute_type=settings.compute_type, + window_lengths=settings.window_lengths, + shift_lengths=settings.shift_lengths, + multiscale_weights=settings.multiscale_weights, + extra_languages=settings.extra_languages, + extra_languages_model_paths=settings.extra_languages_model_paths, + debug_mode=settings.debug, + ) +else: + raise ValueError(f"Invalid ASR type: {settings.asr_type}") + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> None: + """Context manager to handle the startup and shutdown of the application.""" + if retrieve_user_platform() != "linux": + logger.warning( + "You are not running the application on Linux.\n" + "The application was tested on Ubuntu 22.04, so we cannot guarantee that it will work on other OS.\n" + "Report any issues with your env specs to: https://github.com/Wordcab/wordcab-transcribe/issues" + ) + + if settings.extra_languages: + logger.info("Downloading models for extra languages...") + for model in settings.extra_languages: + try: + model_path = download_model( + compute_type=settings.compute_type, language=model + ) + + if model_path is not None: + settings.extra_languages_model_paths[model] = model_path + else: + raise Exception(f"Coudn't download model for {model}") + + except Exception as e: + logger.error(f"Error downloading model for {model}: {e}") + + logger.info("Warmup initialization...") + await asr.inference_warmup() + + yield # This is where the execution of the application starts diff --git a/wordcab_transcribe/dependencies.py b/wordcab_transcribe/dependencies.py deleted file mode 100644 index 0998fdc..0000000 --- a/wordcab_transcribe/dependencies.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2023 The Wordcab Team. All rights reserved. 
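Note on the refactor above and the deletion below: the removed dependencies.py module used to own the shared `asr` instance and the `download_limit` semaphore; both now live at module level in config.py next to the new `lifespan` context manager, and the routers import them from `wordcab_transcribe.config` instead (the `asr` object is imported the same way by the endpoint modules later in this diff). A minimal sketch of the intended wiring, assuming the semaphore is what bounds concurrent audio pre-downloads; the `fetch_audio` helper is illustrative and not part of the codebase:

from fastapi import FastAPI

from wordcab_transcribe.config import download_limit, lifespan

# Model downloads and asr.inference_warmup() run inside `lifespan` before the
# `yield`, replacing the @app.on_event("startup") hook removed from main.py.
app = FastAPI(lifespan=lifespan)


async def fetch_audio(url: str) -> None:
    """Illustrative helper: cap concurrent pre-downloads with the shared semaphore."""
    async with download_limit:  # at most 10 downloads in flight (asyncio.Semaphore(10))
        ...  # e.g. call the download utility from wordcab_transcribe.utils here
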
-# -# Licensed under the Wordcab Transcribe License 0.1 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://github.com/Wordcab/wordcab-transcribe/blob/main/LICENSE -# -# Except as expressly provided otherwise herein, and to the fullest -# extent permitted by law, Licensor provides the Software (and each -# Contributor provides its Contributions) AS IS, and Licensor -# disclaims all warranties or guarantees of any kind, express or -# implied, whether arising under any law or from any usage in trade, -# or otherwise including but not limited to the implied warranties -# of merchantability, non-infringement, quiet enjoyment, fitness -# for a particular purpose, or otherwise. -# -# See the License for the specific language governing permissions -# and limitations under the License. -"""Dependencies for the API.""" - -import asyncio - -from wordcab_transcribe.config import settings -from wordcab_transcribe.services.asr_service import ASRAsyncService, ASRLiveService - - -# Define the ASR service to use depending on the settings -if settings.asr_type == "live": - asr = ASRLiveService() -elif settings.asr_type == "async": - asr = ASRAsyncService() -else: - raise ValueError(f"Invalid ASR type: {settings.asr_type}") - - -# Define the maximum number of files to pre-download for the async ASR service -download_limit = asyncio.Semaphore(10) diff --git a/wordcab_transcribe/logging.py b/wordcab_transcribe/logging.py index 22269c6..881a2b0 100644 --- a/wordcab_transcribe/logging.py +++ b/wordcab_transcribe/logging.py @@ -20,11 +20,10 @@ """Logging module to add a logging middleware to the Wordcab Transcribe API.""" -import asyncio import sys import time -from functools import wraps -from typing import Awaitable, Callable +import uuid +from typing import Any, Awaitable, Callable, Tuple from loguru import logger from starlette.middleware.base import BaseHTTPMiddleware @@ -32,21 +31,19 @@ from starlette.responses import Response from starlette.types import ASGIApp -from wordcab_transcribe.config import settings - class LoggingMiddleware(BaseHTTPMiddleware): """Middleware to log requests, responses, errors and execution time.""" - def __init__(self, app: ASGIApp) -> None: + def __init__(self, app: ASGIApp, debug_mode: bool) -> None: """Initialize the middleware.""" super().__init__(app) logger.remove() logger.add( sys.stdout, level="DEBUG" - if settings.debug - else "WARNING", # Avoid logging debug messages in prod + if debug_mode + else "INFO", # Avoid logging debug messages in prod ) async def dispatch( @@ -63,56 +60,42 @@ async def dispatch( The response from the next middleware. """ start_time = time.time() - logger.debug(f"Request: {request.method} {request.url}") + tracing_id = uuid.uuid4() + + if request.method == "POST": + logger.info(f"Task [{tracing_id}] | {request.method} {request.url}") + else: + logger.info(f"{request.method} {request.url}") response = await call_next(request) process_time = time.time() - start_time - logger.debug( - f"Response status: {response.status_code}, Process Time: {process_time:.4f} secs" + logger.info( + f"Task [{tracing_id}] | Status: {response.status_code}, Time: {process_time:.4f} secs" ) return response -def time_and_tell(func: Callable) -> Callable: +def time_and_tell( + func: Callable, func_name: str, debug_mode: bool +) -> Tuple[Any, float]: """ This decorator logs the execution time of a function only if the debug setting is True. Args: - func: The function to decorate. 
+ func: The function to call in the wrapper. + func_name: The name of the function for logging purposes. + debug_mode: The debug setting for logging purposes. Returns: The appropriate wrapper for the function. """ + start_time = time.time() + result = func() + process_time = time.time() - start_time - @wraps(func) - def sync_wrapper(*args, **kwargs) -> Callable: - """Sync wrapper for the decorated function.""" - if settings.debug: - start_time = time.time() - - result = func(*args, **kwargs) - - process_time = time.time() - start_time - logger.debug(f"{func.__name__} executed in {process_time:.4f} secs") - else: - result = func(*args, **kwargs) - - return result - - async def async_wrapper(*args, **kwargs) -> Awaitable: - """Async wrapper for the decorated function.""" - if settings.debug: - start_time = time.time() - - result = await func(*args, **kwargs) - - process_time = time.time() - start_time - logger.debug(f"{func.__name__} executed in {process_time:.4f} secs") - else: - result = await func(*args, **kwargs) - - return result + if debug_mode: + logger.debug(f"{func_name} executed in {process_time:.4f} secs") - return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper + return result, process_time diff --git a/wordcab_transcribe/main.py b/wordcab_transcribe/main.py index 5fbfc9b..b4fbabe 100644 --- a/wordcab_transcribe/main.py +++ b/wordcab_transcribe/main.py @@ -23,10 +23,8 @@ from fastapi import Depends, FastAPI from fastapi import status as http_status from fastapi.responses import HTMLResponse -from loguru import logger -from wordcab_transcribe.config import settings -from wordcab_transcribe.dependencies import asr +from wordcab_transcribe.config import lifespan, settings from wordcab_transcribe.logging import LoggingMiddleware from wordcab_transcribe.router.authentication import get_current_user from wordcab_transcribe.router.v1.endpoints import ( @@ -34,7 +32,6 @@ auth_router, cortex_router, ) -from wordcab_transcribe.utils import download_model, retrieve_user_platform # Main application instance creation @@ -43,10 +40,11 @@ version=settings.version, openapi_url=f"{settings.api_prefix}/openapi.json", debug=settings.debug, + lifespan=lifespan, ) # Add logging middleware -app.add_middleware(LoggingMiddleware) +app.add_middleware(LoggingMiddleware, debug_mode=settings.debug) # Include the appropiate routers based on the settings if settings.debug is False: @@ -61,40 +59,6 @@ app.include_router(cortex_router, tags=["cortex"]) -@app.on_event("startup") -async def startup_event(): - """Startup event handler.""" - logger.debug("Starting up...") - - if retrieve_user_platform() != "linux": - logger.warning( - "You are not running the application on Linux.\n" - "The application was tested on Ubuntu 22.04, so we cannot guarantee that it will work on other OS.\n" - "Report any issues with your env specs to: https://github.com/Wordcab/wordcab-transcribe/issues" - ) - - if settings.extra_languages: - logger.info("Downloading models for extra languages...") - for model in settings.extra_languages: - try: - model_path = download_model( - compute_type=settings.compute_type, language=model - ) - - if model_path is not None: - settings.extra_languages_model_paths[model] = model_path - else: - raise Exception(f"Coudn't download model for {model}") - - except Exception as e: - logger.error(f"Error downloading model for {model}: {e}") - - logger.info("Warmup initialization...") - await asr.inference_warmup() - - logger.info("Application started!") - - @app.get("/", tags=["status"]) 
async def home() -> HTMLResponse: """Root endpoint returning a simple HTML page with the project info.""" diff --git a/wordcab_transcribe/models.py b/wordcab_transcribe/models.py index f24802d..7f68122 100644 --- a/wordcab_transcribe/models.py +++ b/wordcab_transcribe/models.py @@ -20,11 +20,21 @@ """Models module of the Wordcab Transcribe.""" from enum import Enum -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Union from pydantic import BaseModel, field_validator +class ProcessTimes(BaseModel): + """The execution times of the different processes.""" + + total: float + transcription: float + diarization: Union[float, None] + alignment: Union[float, None] + post_processing: float + + class Timestamps(str, Enum): """Timestamps enum for the API.""" @@ -62,11 +72,15 @@ class BaseResponse(BaseModel): diarization: bool source_lang: str timestamps: str - use_batch: bool vocab: List[str] word_timestamps: bool internal_vad: bool repetition_penalty: float + compression_ratio_threshold: float + log_prob_threshold: float + no_speech_threshold: float + condition_on_previous_text: bool + process_times: ProcessTimes class AudioResponse(BaseResponse): @@ -99,7 +113,6 @@ class Config: "diarization": False, "source_lang": "en", "timestamps": "s", - "use_batch": False, "vocab": [ "custom company name", "custom product name", @@ -108,6 +121,17 @@ class Config: "word_timestamps": False, "internal_vad": False, "repetition_penalty": 1.2, + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": True, + "process_times": { + "total": 2.678, + "transcription": 2.439, + "diarization": None, + "alignment": None, + "post_processing": 0.239, + }, "dual_channel": False, } } @@ -143,7 +167,6 @@ class Config: "diarization": False, "source_lang": "en", "timestamps": "s", - "use_batch": False, "vocab": [ "custom company name", "custom product name", @@ -152,6 +175,17 @@ class Config: "word_timestamps": False, "internal_vad": False, "repetition_penalty": 1.2, + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": True, + "process_times": { + "total": 2.678, + "transcription": 2.439, + "diarization": None, + "alignment": None, + "post_processing": 0.239, + }, "video_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", } } @@ -184,11 +218,14 @@ class CortexPayload(BaseModel): dual_channel: Optional[bool] = False source_lang: Optional[str] = "en" timestamps: Optional[Timestamps] = Timestamps.seconds - use_batch: Optional[bool] = False vocab: Optional[List[str]] = [] word_timestamps: Optional[bool] = False internal_vad: Optional[bool] = False repetition_penalty: Optional[float] = 1.2 + compression_ratio_threshold: Optional[float] = 2.4 + log_prob_threshold: Optional[float] = -1.0 + no_speech_threshold: Optional[float] = 0.6 + condition_on_previous_text: Optional[bool] = True job_name: Optional[str] = None ping: Optional[bool] = False @@ -206,7 +243,6 @@ class Config: "dual_channel": False, "source_lang": "en", "timestamps": "s", - "use_batch": False, "vocab": [ "custom company name", "custom product name", @@ -215,6 +251,10 @@ class Config: "word_timestamps": False, "internal_vad": False, "repetition_penalty": 1.2, + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": True, "job_name": "job_abc123", "ping": False, } @@ -252,7 +292,6 @@ class Config: "diarization": 
False, "source_lang": "en", "timestamps": "s", - "use_batch": False, "vocab": [ "custom company name", "custom product name", @@ -261,6 +300,17 @@ class Config: "word_timestamps": False, "internal_vad": False, "repetition_penalty": 1.2, + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": True, + "process_times": { + "total": 2.678, + "transcription": 2.439, + "diarization": None, + "alignment": None, + "post_processing": 0.239, + }, "dual_channel": False, "job_name": "job_name", "request_id": "request_id", @@ -299,7 +349,6 @@ class Config: "diarization": False, "source_lang": "en", "timestamps": "s", - "use_batch": False, "vocab": [ "custom company name", "custom product name", @@ -308,6 +357,17 @@ class Config: "word_timestamps": False, "internal_vad": False, "repetition_penalty": 1.2, + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": True, + "process_times": { + "total": 2.678, + "transcription": 2.439, + "diarization": None, + "alignment": None, + "post_processing": 0.239, + }, "video_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", "job_name": "job_name", "request_id": "request_id", @@ -323,11 +383,14 @@ class BaseRequest(BaseModel): diarization: bool = False source_lang: str = "en" timestamps: Timestamps = Timestamps.seconds - use_batch: bool = False vocab: List[str] = [] word_timestamps: bool = False internal_vad: bool = False repetition_penalty: float = 1.2 + compression_ratio_threshold: float = 2.4 + log_prob_threshold: float = -1.0 + no_speech_threshold: float = 0.6 + condition_on_previous_text: bool = True @field_validator("vocab") def validate_each_vocab_value( @@ -349,7 +412,6 @@ class Config: "diarization": False, "source_lang": "en", "timestamps": "s", - "use_batch": False, "vocab": [ "custom company name", "custom product name", @@ -358,6 +420,10 @@ class Config: "word_timestamps": False, "internal_vad": False, "repetition_penalty": 1.2, + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": True, } } @@ -377,16 +443,19 @@ class Config: "diarization": False, "source_lang": "en", "timestamps": "s", - "use_batch": False, "vocab": [ "custom company name", "custom product name", "custom co-worker name", ], "word_timestamps": False, - "dual_channel": False, "internal_vad": False, "repetition_penalty": 1.2, + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": True, + "dual_channel": False, } } diff --git a/wordcab_transcribe/router/v1/audio_file_endpoint.py b/wordcab_transcribe/router/v1/audio_file_endpoint.py index 3b04cf7..3455ba0 100644 --- a/wordcab_transcribe/router/v1/audio_file_endpoint.py +++ b/wordcab_transcribe/router/v1/audio_file_endpoint.py @@ -27,7 +27,7 @@ from fastapi import status as http_status from loguru import logger -from wordcab_transcribe.dependencies import asr +from wordcab_transcribe.config import asr from wordcab_transcribe.models import AudioRequest, AudioResponse from wordcab_transcribe.utils import ( convert_file_to_wav, @@ -51,11 +51,14 @@ async def inference_with_audio( # noqa: C901 dual_channel: bool = Form(False), # noqa: B008 source_lang: str = Form("en"), # noqa: B008 timestamps: str = Form("s"), # noqa: B008 - use_batch: bool = Form(False), # noqa: B008 vocab: List[str] = Form([]), # noqa: B008 word_timestamps: bool = Form(False), # 
noqa: B008 internal_vad: bool = Form(False), # noqa: B008 repetition_penalty: float = Form(1.2), # noqa: B008 + compression_ratio_threshold: float = Form(2.4), # noqa: B008 + log_prob_threshold: float = Form(-1.0), # noqa: B008 + no_speech_threshold: float = Form(0.6), # noqa: B008 + condition_on_previous_text: bool = Form(True), # noqa: B008 file: UploadFile = File(...), # noqa: B008 ) -> AudioResponse: """Inference endpoint with audio file.""" @@ -74,11 +77,14 @@ async def inference_with_audio( # noqa: C901 diarization=diarization, source_lang=source_lang, timestamps=timestamps, - use_batch=use_batch, vocab=vocab, word_timestamps=word_timestamps, internal_vad=internal_vad, repetition_penalty=repetition_penalty, + compression_ratio_threshold=compression_ratio_threshold, + log_prob_threshold=log_prob_threshold, + no_speech_threshold=no_speech_threshold, + condition_on_previous_text=condition_on_previous_text, dual_channel=dual_channel, ) @@ -110,11 +116,14 @@ async def inference_with_audio( # noqa: C901 dual_channel=data.dual_channel, source_lang=data.source_lang, timestamps_format=data.timestamps, - use_batch=data.use_batch, vocab=data.vocab, word_timestamps=data.word_timestamps, internal_vad=data.internal_vad, repetition_penalty=data.repetition_penalty, + compression_ratio_threshold=data.compression_ratio_threshold, + log_prob_threshold=data.log_prob_threshold, + no_speech_threshold=data.no_speech_threshold, + condition_on_previous_text=data.condition_on_previous_text, ) ) result = await task @@ -128,7 +137,7 @@ async def inference_with_audio( # noqa: C901 detail=str(result), ) else: - utterances, audio_duration = result + utterances, process_times, audio_duration = result return AudioResponse( utterances=utterances, audio_duration=audio_duration, @@ -138,9 +147,13 @@ async def inference_with_audio( # noqa: C901 dual_channel=data.dual_channel, source_lang=data.source_lang, timestamps=data.timestamps, - use_batch=data.use_batch, vocab=data.vocab, word_timestamps=data.word_timestamps, internal_vad=data.internal_vad, repetition_penalty=data.repetition_penalty, + compression_ratio_threshold=data.compression_ratio_threshold, + log_prob_threshold=data.log_prob_threshold, + no_speech_threshold=data.no_speech_threshold, + condition_on_previous_text=data.condition_on_previous_text, + process_times=process_times, ) diff --git a/wordcab_transcribe/router/v1/audio_url_endpoint.py b/wordcab_transcribe/router/v1/audio_url_endpoint.py index 0c2ed87..27d96d0 100644 --- a/wordcab_transcribe/router/v1/audio_url_endpoint.py +++ b/wordcab_transcribe/router/v1/audio_url_endpoint.py @@ -27,7 +27,7 @@ from fastapi import status as http_status from loguru import logger -from wordcab_transcribe.dependencies import asr, download_limit +from wordcab_transcribe.config import asr, download_limit from wordcab_transcribe.models import AudioRequest, AudioResponse from wordcab_transcribe.utils import ( convert_file_to_wav, @@ -82,11 +82,14 @@ async def inference_with_audio_url( dual_channel=data.dual_channel, source_lang=data.source_lang, timestamps_format=data.timestamps, - use_batch=data.use_batch, vocab=data.vocab, word_timestamps=data.word_timestamps, internal_vad=data.internal_vad, repetition_penalty=data.repetition_penalty, + compression_ratio_threshold=data.compression_ratio_threshold, + log_prob_threshold=data.log_prob_threshold, + no_speech_threshold=data.no_speech_threshold, + condition_on_previous_text=data.condition_on_previous_text, ) ) result = await task @@ -100,7 +103,7 @@ async def 
inference_with_audio_url( detail=str(result), ) else: - utterances, audio_duration = result + utterances, process_times, audio_duration = result return AudioResponse( utterances=utterances, audio_duration=audio_duration, @@ -110,9 +113,13 @@ async def inference_with_audio_url( dual_channel=data.dual_channel, source_lang=data.source_lang, timestamps=data.timestamps, - use_batch=data.use_batch, vocab=data.vocab, word_timestamps=data.word_timestamps, internal_vad=data.internal_vad, repetition_penalty=data.repetition_penalty, + compression_ratio_threshold=data.compression_ratio_threshold, + log_prob_threshold=data.log_prob_threshold, + no_speech_threshold=data.no_speech_threshold, + condition_on_previous_text=data.condition_on_previous_text, + process_times=process_times, ) diff --git a/wordcab_transcribe/router/v1/cortex_endpoint.py b/wordcab_transcribe/router/v1/cortex_endpoint.py index 945c4c1..1fb1119 100644 --- a/wordcab_transcribe/router/v1/cortex_endpoint.py +++ b/wordcab_transcribe/router/v1/cortex_endpoint.py @@ -76,11 +76,14 @@ async def run_cortex( dual_channel=payload.dual_channel, source_lang=payload.source_lang, timestamps=payload.timestamps, - use_batch=payload.use_batch, vocab=payload.vocab, word_timestamps=payload.word_timestamps, internal_vad=payload.internal_vad, repetition_penalty=payload.repetition_penalty, + compression_ratio_threshold=payload.compression_ratio_threshold, + log_prob_threshold=payload.log_prob_threshold, + no_speech_threshold=payload.no_speech_threshold, + condition_on_previous_text=payload.condition_on_previous_text, ) response: AudioResponse = await inference_with_audio_url( background_tasks=BackgroundTasks(), @@ -95,11 +98,14 @@ async def run_cortex( diarization=payload.diarization, source_lang=payload.source_lang, timestamps=payload.timestamps, - use_batch=payload.use_batch, vocab=payload.vocab, word_timestamps=payload.word_timestamps, internal_vad=payload.internal_vad, repetition_penalty=payload.repetition_penalty, + compression_ratio_threshold=payload.compression_ratio_threshold, + log_prob_threshold=payload.log_prob_threshold, + no_speech_threshold=payload.no_speech_threshold, + condition_on_previous_text=payload.condition_on_previous_text, ) response: YouTubeResponse = await inference_with_youtube( background_tasks=BackgroundTasks(), diff --git a/wordcab_transcribe/router/v1/youtube_endpoint.py b/wordcab_transcribe/router/v1/youtube_endpoint.py index 731ec93..8fc29df 100644 --- a/wordcab_transcribe/router/v1/youtube_endpoint.py +++ b/wordcab_transcribe/router/v1/youtube_endpoint.py @@ -27,7 +27,7 @@ from fastapi import status as http_status from loguru import logger -from wordcab_transcribe.dependencies import asr, download_limit +from wordcab_transcribe.config import asr, download_limit from wordcab_transcribe.models import BaseRequest, YouTubeResponse from wordcab_transcribe.utils import delete_file, download_audio_file @@ -58,11 +58,14 @@ async def inference_with_youtube( dual_channel=False, source_lang=data.source_lang, timestamps_format=data.timestamps, - use_batch=data.use_batch, vocab=data.vocab, word_timestamps=data.word_timestamps, internal_vad=data.internal_vad, repetition_penalty=data.repetition_penalty, + compression_ratio_threshold=data.compression_ratio_threshold, + log_prob_threshold=data.log_prob_threshold, + no_speech_threshold=data.no_speech_threshold, + condition_on_previous_text=data.condition_on_previous_text, ) ) result = await task @@ -76,7 +79,7 @@ async def inference_with_youtube( detail=str(result), ) else: - utterances, 
audio_duration = result + utterances, process_times, audio_duration = result return YouTubeResponse( utterances=utterances, audio_duration=audio_duration, @@ -85,10 +88,14 @@ async def inference_with_youtube( diarization=data.diarization, source_lang=data.source_lang, timestamps=data.timestamps, - use_batch=data.use_batch, vocab=data.vocab, word_timestamps=data.word_timestamps, internal_vad=data.internal_vad, repetition_penalty=data.repetition_penalty, + compression_ratio_threshold=data.compression_ratio_threshold, + log_prob_threshold=data.log_prob_threshold, + no_speech_threshold=data.no_speech_threshold, + condition_on_previous_text=data.condition_on_previous_text, + process_times=process_times, video_url=url, ) diff --git a/wordcab_transcribe/services/align_service.py b/wordcab_transcribe/services/align_service.py index 55996f3..74ffb91 100644 --- a/wordcab_transcribe/services/align_service.py +++ b/wordcab_transcribe/services/align_service.py @@ -31,7 +31,6 @@ from loguru import logger from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor -from wordcab_transcribe.logging import time_and_tell from wordcab_transcribe.utils import interpolate_nans @@ -151,7 +150,6 @@ def __init__(self, device: str) -> None: self.model_map = MODEL_MAPPING self.available_lang = self.model_map.keys() - @time_and_tell def __call__( self, filepath: str, diff --git a/wordcab_transcribe/services/asr_service.py b/wordcab_transcribe/services/asr_service.py index 2c66e35..e930513 100644 --- a/wordcab_transcribe/services/asr_service.py +++ b/wordcab_transcribe/services/asr_service.py @@ -22,14 +22,14 @@ import asyncio import functools import os +import time import traceback from abc import ABC, abstractmethod -from typing import List, Tuple, Union +from typing import Dict, List, Tuple, Union import torch from loguru import logger -from wordcab_transcribe.config import settings from wordcab_transcribe.logging import time_and_tell from wordcab_transcribe.services.align_service import AlignService from wordcab_transcribe.services.diarization.diarize_service import DiarizeService @@ -75,8 +75,30 @@ async def process_input(self) -> None: class ASRAsyncService(ASRService): """ASR Service module for async endpoints.""" - def __init__(self) -> None: - """Initialize the ASRAsyncService class.""" + def __init__( + self, + whisper_model: str, + compute_type: str, + window_lengths: List[int], + shift_lengths: List[int], + multiscale_weights: List[float], + extra_languages: List[str], + extra_languages_model_paths: List[str], + debug_mode: bool, + ) -> None: + """ + Initialize the ASRAsyncService class. + + Args: + whisper_model (str): The path to the whisper model. + compute_type (str): The compute type to use for inference. + window_lengths (List[int]): The window lengths to use for diarization. + shift_lengths (List[int]): The shift lengths to use for diarization. + multiscale_weights (List[float]): The multiscale weights to use for diarization. + extra_languages (List[str]): The list of extra languages to support. + extra_languages_model_paths (List[str]): The list of paths to the extra language models. + debug_mode (bool): Whether to run in debug mode. 
+ """ super().__init__() if self.num_gpus > 1 and self.device == "cuda": @@ -88,17 +110,19 @@ def __init__(self) -> None: self.services: dict = { "transcription": TranscribeService( - model_path=settings.whisper_model, - compute_type=settings.compute_type, + model_path=whisper_model, + compute_type=compute_type, device=self.device, device_index=device_index, + extra_languages=extra_languages, + extra_languages_model_paths=extra_languages_model_paths, ), "diarization": DiarizeService( device=self.device, device_index=device_index, - window_lengths=settings.window_lengths, - shift_lengths=settings.shift_lengths, - multiscale_weights=settings.multiscale_weights, + window_lengths=window_lengths, + shift_lengths=shift_lengths, + multiscale_weights=multiscale_weights, ), "alignment": AlignService(self.device), "post_processing": PostProcessingService(), @@ -113,6 +137,8 @@ def __init__(self) -> None: "temperature": 0.0, } + self.debug_mode = debug_mode + async def inference_warmup(self) -> None: """Warmup the GPU by loading the models.""" for gpu_index in self.gpu_handler.device_index: @@ -125,14 +151,16 @@ async def inference_warmup(self) -> None: dual_channel=False, source_lang="en", timestamps_format="s", - use_batch=False, vocab=[], word_timestamps=False, internal_vad=False, repetition_penalty=1.0, + compression_ratio_threshold=2.4, + log_prob_threshold=-1.0, + no_speech_threshold=0.6, + condition_on_previous_text=True, ) - @time_and_tell async def process_input( self, filepath: Union[str, Tuple[str, str]], @@ -142,12 +170,15 @@ async def process_input( dual_channel: bool, source_lang: str, timestamps_format: str, - use_batch: bool, vocab: List[str], word_timestamps: bool, internal_vad: bool, repetition_penalty: float, - ) -> Union[Tuple[List[dict], float], Exception]: + compression_ratio_threshold: float, + log_prob_threshold: float, + no_speech_threshold: float, + condition_on_previous_text: bool, + ) -> Union[Tuple[List[dict], Dict[str, float], float], Exception]: """Process the input request and return the results. This method will create a task and add it to the appropriate queues. @@ -157,22 +188,47 @@ async def process_input( and stored in separated keys in the task dictionary. Args: - filepath (Union[str, Tuple[str, str]]): Path to the audio file or tuple of paths to the audio files. - alignment (bool): Whether to do alignment or not. - num_speakers (int): The number of oracle speakers. - diarization (bool): Whether to do diarization or not. - dual_channel (bool): Whether to do dual channel or not. - source_lang (str): Source language of the audio file. - timestamps_format (str): Timestamps format to use. - use_batch (bool): Whether to use batch processing or not. - vocab (List[str]): List of words to use for the vocabulary. - word_timestamps (bool): Whether to return word timestamps or not. - internal_vad (bool): Whether to use faster-whisper's VAD or not. - repetition_penalty (float): The repetition penalty to use for the beam search. + filepath (Union[str, Tuple[str, str]]): + Path to the audio file or tuple of paths to the audio files. + alignment (bool): + Whether to do alignment or not. + num_speakers (int): + The number of oracle speakers. + diarization (bool): + Whether to do diarization or not. + dual_channel (bool): + Whether to do dual channel or not. + source_lang (str): + Source language of the audio file. + timestamps_format (str): + Timestamps format to use. + vocab (List[str]): + List of words to use for the vocabulary. 
+ word_timestamps (bool): + Whether to return word timestamps or not. + internal_vad (bool): + Whether to use faster-whisper's VAD or not. + repetition_penalty (float): + The repetition penalty to use for the beam search. + compression_ratio_threshold (float): + If the gzip compression ratio is above this value, treat as failed. + log_prob_threshold (float): + If the average log probability over sampled tokens is below this value, treat as failed. + no_speech_threshold (float): + If the no_speech probability is higher than this value AND the average log probability + over sampled tokens is below `log_prob_threshold`, consider the segment as silent. + condition_on_previous_text (bool): + If True, the previous output of the model is provided as a prompt for the next window; + disabling may make the text inconsistent across windows, but the model becomes less prone + to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. Returns: - Union[Tuple[List[dict], float], Exception]: The final transcription result associated with the audio - duration or an exception. + Union[Tuple[List[dict], Dict[str, float], float], Exception]: + The results of the ASR pipeline or an exception if something went wrong. + Results are returned as a tuple of the following: + * List[dict]: The final results of the ASR pipeline. + * Dict[str, float]: the process times for each step. + * float: The audio duration """ if isinstance(filepath, tuple): audio, duration = [], [] @@ -197,11 +253,14 @@ async def process_input( "dual_channel": dual_channel, "source_lang": source_lang, "timestamps_format": timestamps_format, - "use_batch": use_batch, "vocab": vocab, "word_timestamps": word_timestamps, "internal_vad": internal_vad, "repetition_penalty": repetition_penalty, + "compression_ratio_threshold": compression_ratio_threshold, + "log_prob_threshold": log_prob_threshold, + "no_speech_threshold": no_speech_threshold, + "condition_on_previous_text": condition_on_previous_text, "transcription_result": None, "transcription_done": asyncio.Event(), "diarization_result": None, @@ -210,22 +269,31 @@ async def process_input( "alignment_done": asyncio.Event(), "post_processing_result": None, "post_processing_done": asyncio.Event(), - "time": asyncio.get_event_loop().time(), + "process_times": {}, } # Pick the first available GPU for the task gpu_index = await self.gpu_handler.get_device() if self.device == "cuda" else 0 logger.info(f"Using GPU {gpu_index} for the task") + start_process_time = time.time() + asyncio.get_event_loop().run_in_executor( - None, functools.partial(self.process_transcription, task, gpu_index) + None, + functools.partial( + self.process_transcription, task, gpu_index, self.debug_mode + ), ) if diarization and dual_channel is False: asyncio.get_event_loop().run_in_executor( - None, functools.partial(self.process_diarization, task, gpu_index) + None, + functools.partial( + self.process_diarization, task, gpu_index, self.debug_mode + ), ) else: + task["process_times"]["diarization"] = None task["diarization_done"].set() await task["transcription_done"].wait() @@ -241,9 +309,13 @@ async def process_input( else: if alignment and dual_channel is False: asyncio.get_event_loop().run_in_executor( - None, functools.partial(self.process_alignment, task, gpu_index) + None, + functools.partial( + self.process_alignment, task, gpu_index, self.debug_mode + ), ) else: + task["process_times"]["alignment"] = None task["alignment_done"].set() await task["alignment_done"].wait() @@ -266,108 
+338,132 @@ async def process_input( return task["post_processing_result"] result = task.pop("post_processing_result") + process_times: Dict[str, float] = task.pop("process_times") + process_times["total"] = time.time() - start_process_time + del task # Delete the task to free up memory - return result, duration + return result, process_times, duration - @time_and_tell - def process_transcription(self, task: dict, gpu_index: int) -> None: + def process_transcription( + self, task: dict, gpu_index: int, debug_mode: bool + ) -> None: """ Process a task of transcription and update the task with the result. Args: task (dict): The task and its parameters. gpu_index (int): The GPU index to use for the transcription. + debug_mode (bool): Whether to run in debug mode or not. Returns: None: The task is updated with the result. """ try: - segments = self.services["transcription"]( - task["input"], - source_lang=task["source_lang"], - model_index=gpu_index, - suppress_blank=False, - vocab=None if task["vocab"] == [] else task["vocab"], - word_timestamps=True, - internal_vad=task["internal_vad"], - repetition_penalty=task["repetition_penalty"], - vad_service=self.services["vad"] if task["dual_channel"] else None, - use_batch=task["use_batch"], + result, process_time = time_and_tell( + lambda: self.services["transcription"]( + task["input"], + source_lang=task["source_lang"], + model_index=gpu_index, + suppress_blank=False, + vocab=None if task["vocab"] == [] else task["vocab"], + word_timestamps=True, + internal_vad=task["internal_vad"], + repetition_penalty=task["repetition_penalty"], + compression_ratio_threshold=task["compression_ratio_threshold"], + log_prob_threshold=task["log_prob_threshold"], + no_speech_threshold=task["no_speech_threshold"], + condition_on_previous_text=task["condition_on_previous_text"], + vad_service=self.services["vad"] if task["dual_channel"] else None, + ), + func_name="transcription", + debug_mode=debug_mode, ) - result = segments except Exception as e: result = Exception( f"Error in transcription gpu {gpu_index}: {e}\n{traceback.format_exc()}" ) + process_time = None finally: + task["process_times"]["transcription"] = process_time task["transcription_result"] = result task["transcription_done"].set() return None - @time_and_tell - def process_diarization(self, task: dict, gpu_index: int) -> None: + def process_diarization(self, task: dict, gpu_index: int, debug_mode: bool) -> None: """ Process a task of diarization. Args: task (dict): The task and its parameters. gpu_index (int): The GPU index to use for the diarization. + debug_mode (bool): Whether to run in debug mode or not. Returns: None: The task is updated with the result. 
""" try: - result = self.services["diarization"]( - task["input"], - audio_duration=task["duration"], - oracle_num_speakers=task["num_speakers"], - model_index=gpu_index, - vad_service=self.services["vad"], + result, process_time = time_and_tell( + lambda: self.services["diarization"]( + task["input"], + audio_duration=task["duration"], + oracle_num_speakers=task["num_speakers"], + model_index=gpu_index, + vad_service=self.services["vad"], + ), + func_name="diarization", + debug_mode=debug_mode, ) except Exception as e: result = Exception(f"Error in diarization: {e}\n{traceback.format_exc()}") + process_time = None finally: + task["process_times"]["diarization"] = process_time task["diarization_result"] = result task["diarization_done"].set() return None - @time_and_tell - def process_alignment(self, task: dict, gpu_index: int) -> None: + def process_alignment(self, task: dict, gpu_index: int, debug_mode: bool) -> None: """ Process a task of alignment. Args: task (dict): The task and its parameters. gpu_index (int): The GPU index to use for the alignment. + debug_mode (bool): Whether to run in debug mode or not. Returns: None: The task is updated with the result. """ try: - segments = self.services["alignment"]( - task["input"], - transcript_segments=task["transcription_result"], - source_lang=task["source_lang"], - gpu_index=gpu_index, + result, process_time = time_and_tell( + lambda: self.services["alignment"]( + task["input"], + transcript_segments=task["transcription_result"], + source_lang=task["source_lang"], + gpu_index=gpu_index, + ), + func_name="alignment", + debug_mode=debug_mode, ) except Exception as e: - segments = Exception(f"Error in alignment: {e}\n{traceback.format_exc()}") + result = Exception(f"Error in alignment: {e}\n{traceback.format_exc()}") + process_time = None finally: - task["alignment_result"] = segments + task["process_times"]["alignment"] = process_time + task["alignment_result"] = result task["alignment_done"].set() return None - @time_and_tell def process_post_processing(self, task: dict) -> None: """ Process a task of post processing. @@ -379,6 +475,7 @@ def process_post_processing(self, task: dict) -> None: None: The task is updated with the result. 
""" try: + total_post_process_time = 0 alignment = task["alignment"] diarization = task["diarization"] dual_channel = task["dual_channel"] @@ -386,12 +483,18 @@ def process_post_processing(self, task: dict) -> None: if dual_channel: left_segments, right_segments = task["transcription_result"] - utterances = self.services[ - "post_processing" - ].dual_channel_speaker_mapping( - left_segments=left_segments, - right_segments=right_segments, + utterances, process_time = time_and_tell( + lambda: self.services[ + "post_processing" + ].dual_channel_speaker_mapping( + left_segments=left_segments, + right_segments=right_segments, + ), + func_name="dual_channel_speaker_mapping", + debug_mode=self.debug_mode, ) + total_post_process_time += process_time + else: segments = ( task["alignment_result"] @@ -399,40 +502,56 @@ def process_post_processing(self, task: dict) -> None: else task["transcription_result"] ) - formatted_segments = format_segments( - segments=segments, - alignment=alignment, - use_batch=task["use_batch"], - word_timestamps=True, + formatted_segments, process_time = time_and_tell( + lambda: format_segments( + segments=segments, + alignment=alignment, + word_timestamps=True, + ), + func_name="format_segments", + debug_mode=self.debug_mode, ) + total_post_process_time += process_time if diarization: - utterances = self.services[ - "post_processing" - ].single_channel_speaker_mapping( - transcript_segments=formatted_segments, - speaker_timestamps=task["diarization_result"], - word_timestamps=word_timestamps, + utterances, process_time = time_and_tell( + lambda: self.services[ + "post_processing" + ].single_channel_speaker_mapping( + transcript_segments=formatted_segments, + speaker_timestamps=task["diarization_result"], + word_timestamps=word_timestamps, + ), + func_name="single_channel_speaker_mapping", + debug_mode=self.debug_mode, ) + total_post_process_time += process_time else: utterances = formatted_segments - final_utterances = self.services[ - "post_processing" - ].final_processing_before_returning( - utterances=utterances, - diarization=diarization, - dual_channel=task["dual_channel"], - timestamps_format=task["timestamps_format"], - word_timestamps=word_timestamps, + final_utterances, process_time = time_and_tell( + lambda: self.services[ + "post_processing" + ].final_processing_before_returning( + utterances=utterances, + diarization=diarization, + dual_channel=task["dual_channel"], + timestamps_format=task["timestamps_format"], + word_timestamps=word_timestamps, + ), + func_name="final_processing_before_returning", + debug_mode=self.debug_mode, ) + total_post_process_time += process_time except Exception as e: final_utterances = Exception( f"Error in post-processing: {e}\n{traceback.format_exc()}" ) + total_post_process_time = None finally: + task["process_times"]["post_processing"] = total_post_process_time task["post_processing_result"] = final_utterances task["post_processing_done"].set() diff --git a/wordcab_transcribe/services/transcribe_service.py b/wordcab_transcribe/services/transcribe_service.py index ae3fe3b..936dea3 100644 --- a/wordcab_transcribe/services/transcribe_service.py +++ b/wordcab_transcribe/services/transcribe_service.py @@ -19,255 +19,17 @@ # and limitations under the License. 
"""Transcribe Service for audio files.""" -import itertools -import math -import os -import zlib -from pathlib import Path -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union +from typing import List, NamedTuple, Optional, Tuple, Union import numpy as np import torch -import torch.nn.functional as F # noqa N812 -import torchaudio -from ctranslate2 import StorageView -from ctranslate2.models import WhisperGenerationResult from faster_whisper import WhisperModel -from faster_whisper.tokenizer import Tokenizer -from faster_whisper.transcribe import get_ctranslate2_storage from loguru import logger -from torch.utils.data import DataLoader, IterableDataset -from wordcab_transcribe.config import settings -from wordcab_transcribe.logging import time_and_tell from wordcab_transcribe.services.vad_service import VadService from wordcab_transcribe.utils import enhance_audio -class DualChannelInput(NamedTuple): - """Tuple used for dual channel processing. - - The first element is the index of the VAD group. The second element is the audio tensor. - """ - - group_index: int - audio: torch.Tensor - - -# Word implementation from faster-whisper: -# https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py#L24 -class Word(NamedTuple): - """Word unit for word_timestamps option.""" - - start: float - end: float - word: str - probability: float - - -class AudioDataset(IterableDataset): - """Audio Dataset for transcribing audio files in batches.""" - - def __init__( - self, - audio: Union[str, torch.Tensor, List[DualChannelInput]], - chunk_size: int, - hop_length: int, - mel_filters: torch.Tensor, - n_fft: int, - n_samples: int, - sample_rate: int, - ) -> None: - """ - Initialize the Audio Dataset for transcribing audio files in batches. - - Args: - audio (Union[str, torch.Tensor, List[DualChannelInput]]): Audio to transcribe. - chunk_size (int): Size of audio chunks. - hop_length (int): Hop length for the STFT. - mel_filters (torch.Tensor): Mel filters to apply to the STFT. - n_fft (int): Size of the FFT. - n_samples (int): Number of samples to pad the audio. - sample_rate (int): Sample rate of the audio. 
- """ - if isinstance(audio, list): - if not all(isinstance(a, DualChannelInput) for a in audio): - raise TypeError("Audio must be a list of DualChannelInput.") - - self.chunk_size = chunk_size - self.hop_length = hop_length - self.n_fft = n_fft - self.n_samples = n_samples - self.mel_filters = mel_filters - self.sample_rate = sample_rate - - if isinstance(audio, str): - waveform = self.read_audio(audio) - else: - waveform = audio - - ( - self.indexes, - _audio_chunks, - self.time_offsets, - self.segment_durations, - self.group_ids, - ) = self.create_chunks(waveform) - - self.features = [ - self._log_mel_spectrogram(chunk, padding=self.n_samples - chunk.shape[-1]) - for chunk in _audio_chunks - ] - - def __len__(self) -> int: - """Get the number of audio chunks.""" - return len(self.indexes) - - def __iter__(self) -> Iterator[Dict[str, Union[torch.Tensor, int, float, None]]]: - """Iterate over the audio chunks and yield the features.""" - if self.group_ids is None: - group_ids_iter = itertools.repeat(None) - else: - group_ids_iter = iter(self.group_ids) - - for index, feature, time_offset, segment_duration, group_id in zip( - self.indexes, - self.features, - self.time_offsets, - self.segment_durations, - group_ids_iter, - ): - yield { - "index": index, - "feature": feature, - "time_offset": time_offset, - "segment_duration": segment_duration, - "group_id": group_id, - } - - def read_audio(self, filepath: str) -> torch.Tensor: - """ - Read an audio file and return the audio tensor. - - Args: - filepath (str): Path to the audio file. - - Returns: - torch.Tensor: Audio tensor. - """ - wav, sr = torchaudio.load(filepath) - - if wav.size(0) > 1: - wav = wav.mean(dim=0, keepdim=True) - - if sr != self.sample_rate: - transform = torchaudio.transforms.Resample( - orig_freq=sr, new_freq=self.sample_rate - ) - wav = transform(wav) - sr = self.sample_rate - - return wav.squeeze(0) - - @time_and_tell - def create_chunks( - self, waveform: Union[torch.Tensor, List[DualChannelInput]] - ) -> Tuple[ - List[int], List[torch.Tensor], List[int], List[float], Union[None, List[int]] - ]: - """ - Create 30-second chunks from the audio file loaded as a tensor. - - Args: - waveform (Union[torch.Tensor, List[DualChannelInput]]): Audio to transcribe. - - Returns: - Tuple[List[int], List[torch.Tensor], List[int], List[float], Union[None, List[int]]]: - """ - if isinstance(waveform, torch.Tensor): - num_segments = math.ceil(waveform.size(0) / self.n_samples) - segments = [ - waveform[i * self.n_samples : (i + 1) * self.n_samples] - for i in range(num_segments) - ] - group_ids = None - - elif isinstance(waveform, list): - segments = [segment.audio for segment in waveform] - group_ids = [segment.group_index for segment in waveform] - num_segments = len(segments) - - indexes = [i for i in range(num_segments)] - time_offsets = [(i * self.chunk_size) for i in range(num_segments)] - segment_durations = [ - self.chunk_size - if len(segment) == self.n_samples - else len(segment) / self.sample_rate - for segment in segments - ] - - return indexes, segments, time_offsets, segment_durations, group_ids - - def _log_mel_spectrogram( - self, audio: torch.Tensor, padding: int = 0 - ) -> torch.Tensor: - """ - Compute the log-Mel spectrogram of a given audio tensor. - - Args: - audio (torch.Tensor): Audio tensor of shape (n_samples,). - padding (int): Number of samples to pad the audio. - - Returns: - torch.Tensor: Log-Mel spectrogram of shape (n_mels, T). 
- """ - if padding > 0: - audio = F.pad(audio, (0, padding)) - - window = torch.hann_window(self.n_fft).to(audio.device) - stft = torch.stft( - audio, self.n_fft, self.hop_length, window=window, return_complex=True - ) - - magnitudes = stft[..., :-1].abs() ** 2 - mel_spec = self.mel_filters @ magnitudes - - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - - return log_spec - - -class FallBackDataset(IterableDataset): - """Custom Dataset for transcribing fallback segments in batches.""" - - def __init__(self, failed_segments: List[Dict[str, Any]]) -> None: - """ - Initialize the Dataset. - - Args: - failed_segments (List[Dict[str, Any]]): List of failed segments. - """ - self.segments = failed_segments - - def __iter__(self) -> Dict[str, Any]: - """ - Iterate over the failed segments and yield the features. - - Yields: - Dict[str, Any]: Dictionary containing the features. - A segment looks like this: - { - "index": 0, # Index of the segment in the original list of segments. - "feature": torch.Tensor, # Tensor of shape (n_mels, T). - "time_offset": 0, - "segment_duration": 30.0, - } - """ - yield from self.segments - - class FasterWhisperModel(NamedTuple): """Faster Whisper Model.""" @@ -284,31 +46,31 @@ def __init__( compute_type: str, device: str, device_index: Union[int, List[int]], + extra_languages: Union[List[str], None] = None, + extra_languages_model_paths: Union[List[str], None] = None, ) -> None: """Initialize the Transcribe Service. This service uses the WhisperModel from faster-whisper to transcribe audio files. Args: - model_path (str): Path to the model checkpoint. This can be a local path or a URL. - compute_type (str): Compute type to use for inference. Can be "int8", "int8_float16", "int16" or "float_16". - device (str): Device to use for inference. Can be "cpu" or "cuda". - device_index (Union[int, List[int]]): Index of the device to use for inference. + model_path (str): + Path to the model checkpoint. This can be a local path or a URL. + compute_type (str): + Compute type to use for inference. Can be "int8", "int8_float16", "int16" or "float_16". + device (str): + Device to use for inference. Can be "cpu" or "cuda". + device_index (Union[int, List[int]]): + Index of the device to use for inference. + extra_languages (Union[List[str], None]): + List of extra languages to transcribe. Defaults to None. + extra_languages_model_paths (Union[List[str], None]): + List of paths to the extra language models. Defaults to None. 
""" self.device = device self.compute_type = compute_type self.model_path = model_path - # self.models = {} - # for idx in device_index: - # model = WhisperModel( - # self.model_path, - # device=self.device, - # device_index=idx, - # compute_type=self.compute_type, - # ) - # self.models[idx] = FasterWhisperModel(model=model, lang="multi") - # logger.debug(f"Loaded {len(self.models)} models for transcription.") self.model = WhisperModel( self.model_path, device=self.device, @@ -316,29 +78,8 @@ def __init__( compute_type=self.compute_type, ) - self.extra_lang = settings.extra_languages - self.extra_lang_models = settings.extra_languages_model_paths - - self._batch_size = 8 # TODO: Make this configurable - self.sample_rate = 16000 - - self.n_fft = 400 - self.n_mels = 80 - self.chunk_size = 30 - self.hop_length = 160 - - self.n_samples = self.sample_rate * self.chunk_size - self.tokens_per_second = self.sample_rate // self.hop_length - - assets_dir = Path(__file__).parent.parent / "assets" / "mel_filters.npz" - with np.load(str(assets_dir)) as f: - self.mel_filters = torch.from_numpy(f[f"mel_{self.n_mels}"]) - - self.compression_ratio_threshold = 2.4 - self.log_probability_threshold = -0.8 - - self.prepend_punctuation = "\"'“¿([{-" - self.append_punctuation = "\"'.。,,!!??::”)]}、" + self.extra_lang = extra_languages + self.extra_lang_models = extra_languages_model_paths def __call__( self, @@ -348,29 +89,50 @@ def __call__( source_lang: str, model_index: int, suppress_blank: bool = False, - vocab: Optional[List[str]] = None, + vocab: Union[List[str], None] = None, word_timestamps: bool = True, internal_vad: bool = False, repetition_penalty: float = 1.0, - vad_service: Optional[VadService] = None, - use_batch: bool = True, + compression_ratio_threshold: float = 2.4, + log_prob_threshold: float = -1.0, + no_speech_threshold: float = 0.6, + condition_on_previous_text: bool = True, + vad_service: Union[VadService, None] = None, ) -> Union[List[dict], List[List[dict]]]: """ Run inference with the transcribe model. Args: - audio (Union[str, torch.Tensor, Tuple[str, str], Tuple[torch.Tensor, torch.Tensor]]): Audio file path or - audio tensor. If a tuple is passed, the task is assumed to be a dual_channel task and the tuple should - contain the paths to the two audio files. - source_lang (str): Language of the audio file. - model_index (int): Index of the model to use. - suppress_blank (bool): Whether to suppress blank at the beginning of the sampling. - vocab (Optional[List[str]]): Vocabulary to use during generation if not None. - word_timestamps (bool): Whether to return word timestamps. - internal_vad (bool): Whether to use faster-whisper's VAD or not. - repetition_penalty (float): Repetition penalty to use during generation beamed search. - vad_service (Optional[VADService]): VADService to use for voice activity detection in the dual_channel case. - use_batch (bool): Whether to use batch inference. + audio (Union[str, torch.Tensor, Tuple[str, str], Tuple[torch.Tensor, torch.Tensor]]): + Audio file path or audio tensor. If a tuple is passed, the task is assumed + to be a dual_channel task and the tuple should contain the paths to the two audio files. + source_lang (str): + Language of the audio file. + model_index (int): + Index of the model to use. + suppress_blank (bool): + Whether to suppress blank at the beginning of the sampling. + vocab (Union[List[str], None]): + Vocabulary to use during generation if not None. Defaults to None. + word_timestamps (bool): + Whether to return word timestamps. 
+ internal_vad (bool): + Whether to use faster-whisper's VAD or not. + repetition_penalty (float): + Repetition penalty to use during generation beamed search. + compression_ratio_threshold (float): + If the gzip compression ratio is above this value, treat as failed. + log_prob_threshold (float): + If the average log probability over sampled tokens is below this value, treat as failed. + no_speech_threshold (float): + If the no_speech probability is higher than this value AND the average log probability + over sampled tokens is below `log_prob_threshold`, consider the segment as silent. + condition_on_previous_text (bool): + If True, the previous output of the model is provided as a prompt for the next window; + disabling may make the text inconsistent across windows, but the model becomes less prone + to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. + vad_service (Union[VadService, None]): + VADService to use for voice activity detection in the dual_channel case. Defaults to None. Returns: Union[List[dict], List[List[dict]]]: List of transcriptions. If the task is a dual_channel task, @@ -405,7 +167,6 @@ def __call__( # lang=source_lang, # ) - # if not use_batch and not isinstance(audio, tuple): if ( vocab is not None and isinstance(vocab, list) @@ -426,8 +187,12 @@ def __call__( language=source_lang, initial_prompt=prompt, repetition_penalty=repetition_penalty, - suppress_blank=False, - word_timestamps=True, + compression_ratio_threshold=compression_ratio_threshold, + log_prob_threshold=log_prob_threshold, + no_speech_threshold=no_speech_threshold, + condition_on_previous_text=condition_on_previous_text, + suppress_blank=suppress_blank, + word_timestamps=word_timestamps, vad_filter=internal_vad, vad_parameters=dict( threshold=0.5, @@ -437,13 +202,6 @@ def __call__( window_size_samples=512, ), ) - # segments, _ = self.models[model_index].model.transcribe( - # audio, - # language=source_lang, - # initial_prompt=prompt, - # suppress_blank=False, - # word_timestamps=True, - # ) segments = list(segments) if not segments: @@ -455,6 +213,10 @@ def __call__( language=source_lang, initial_prompt=prompt, repetition_penalty=repetition_penalty, + compression_ratio_threshold=compression_ratio_threshold, + log_prob_threshold=log_prob_threshold, + no_speech_threshold=no_speech_threshold, + condition_on_previous_text=condition_on_previous_text, suppress_blank=False, word_timestamps=True, vad_filter=False if internal_vad else True, @@ -475,394 +237,6 @@ def __call__( ) ) - # else: - # tokenizer = Tokenizer( - # self.model.hf_tokenizer, - # self.model.model.is_multilingual, - # task="transcribe", - # language=source_lang, - # ) - - # if isinstance(audio, tuple): - # outputs = [] - # for audio_index, audio_file in enumerate(audio): - # outputs.append( - # self._transcribe_dual_channel( - # self.model, - # tokenizer, - # audio_file, - # audio_index, - # vad_service, - # ) - # ) - - # else: - # outputs = self.pipeline( - # self.model, - # tokenizer, - # audio, - # self._batch_size, - # suppress_blank, - # word_timestamps, - # ) - - return outputs - - @time_and_tell - def pipeline( - self, - model: WhisperModel, - tokenizer: Tokenizer, - audio: Union[str, torch.Tensor, List[DualChannelInput]], - batch_size: int, - suppress_blank: bool = True, - word_timestamps: bool = False, - ) -> List[dict]: - """ - Transcription pipeline for audio chunks in batches. - - Args: - model (WhisperModel): Model to use for inference. 
- tokenizer (Tokenizer): Tokenizer to use for inference. - audio (Union[str, torch.Tensor, List[DualChannelInput]]): Audio file path, audio tensor or list of - DualChannelInput objects. - batch_size (int): Batch size to use for inference. - suppress_blank (bool): Whether to suppress blank at the beginning of the sampling. - word_timestamps (bool): Whether to return word timestamps. - - Returns: - List[dict]: List of segments with the following keys: "start", "end", "text". - """ - dataset = AudioDataset( - audio=audio, - chunk_size=self.chunk_size, - hop_length=self.hop_length, - mel_filters=self.mel_filters, - n_fft=self.n_fft, - n_samples=self.n_samples, - sample_rate=self.sample_rate, - ) - dataloader = DataLoader( - dataset, - batch_size=batch_size, - collate_fn=self._collate_fn, - ) - - _outputs = [None for _ in range(len(dataset))] - - # The first pass of inference is done with non-greedy settings to achieve better results. - beam_size = 5 - num_hypotheses = 1 - patience = 1.0 - sampling_top_k = 1 - temperature = 1.0 - stop_temperature = None - - while True: - outputs_that_need_reprocessing = [] - - for batch in dataloader: - batch_outputs = self._generate_segment_batched( - model=model, - features=batch["features"], - time_offsets=batch["time_offsets"], - segment_durations=batch["segment_durations"], - group_ids=batch["group_ids"], - tokenizer=tokenizer, - beam_size=beam_size, - num_hypotheses=num_hypotheses, - patience=patience, - sampling_top_k=sampling_top_k, - suppress_blank=suppress_blank, - temperature=temperature, - last_chance_inference=False if stop_temperature != 1.0 else True, - word_timestamps=word_timestamps, - ) - - for output_index, output in enumerate(batch_outputs): - if output["need_fallback"]: - outputs_that_need_reprocessing.append( - { - "index": batch["indexes"][output_index], - "feature": batch["features"][output_index], - "time_offset": batch["time_offsets"][output_index], - "segment_duration": batch["segment_durations"][ - output_index - ], - "group_id": batch["group_ids"][output_index], - } - ) - else: - _outputs[batch["indexes"][output_index]] = output["segments"] - - if len(outputs_that_need_reprocessing) > 0 and stop_temperature != 1.0: - dataloader = DataLoader( - FallBackDataset(outputs_that_need_reprocessing), - batch_size=batch_size, - collate_fn=self._collate_fn, - ) - # The second pass of inference is done with greedy settings to speed up the process. - beam_size = 1 - num_hypotheses = 5 - sampling_top_k = 0 - temperature = (temperature + 0.2) if temperature != 1.0 else 0.2 - stop_temperature = temperature # Used to stop the loop if the temperature reaches 1.0 again. - else: - break # All segments have been processed successfully. 
- - outputs = list(itertools.chain.from_iterable(_outputs)) - - return outputs - - # This is an adapted version of the faster-whisper transcription pipeline: - # https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py - @time_and_tell - def _generate_segment_batched( - self, - model: WhisperModel, - features: torch.Tensor, - time_offsets: List[float], - segment_durations: List[float], - group_ids: List[int], - tokenizer: Tokenizer, - beam_size: int = 5, - initial_prompt: Optional[str] = None, - last_chance_inference: bool = False, - length_penalty: float = 1.0, - patience: float = 1.0, - prefix: Optional[str] = None, - num_hypotheses: int = 1, - sampling_top_k: int = 1, - suppress_blank: bool = True, - temperature: float = 1.0, - without_timestamps: bool = False, - word_timestamps: bool = False, - ) -> List[dict]: - """ - Use the ctranslate2 Whisper model to generate text from audio chunks. - - Args: - model (WhisperModel): Model to use for inference. - features (torch.Tensor): List of audio chunks. - time_offsets (List[float]): Time offsets for the audio chunks. - segment_durations (List[float]): Durations of the audio chunks. - group_ids (List[int]): Group ids of the audio chunks. - tokenizer (Tokenizer): Tokenizer to use for encoding the text. - beam_size (int): Beam size to use for beam search. - last_chance_inference (bool): Whether to accept the result of the inference even if not perfect. - length_penalty (float): Length penalty to use for beam search. - initial_prompt (Optional[str]): Initial prompt to use for the generation. - num_hypotheses (int): Number of hypotheses used by generate. - patience (float): Patience to use for beam search. - prefix (Optional[str]): Prefix to use for the generation. - sampling_top_k (int): Sampling top k to use for sampling. - suppress_blank (bool): Whether to suppress blank output of the sampling. - temperature (float): Temperature to use for sampling. - without_timestamps (bool): Whether to remove timestamps from the generated text. - word_timestamps (bool): Whether to use word timestamps instead of character timestamps. - - Returns: - List[dict]: List of segments with the following keys: "segments", "need_fallback". - """ - if "TOKENIZERS_PARALLELISM" not in os.environ: - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - batch_size = features.size(0) - - all_tokens = [] - prompt_reset_since = 0 - - if initial_prompt is not None: - initial_prompt = " " + initial_prompt.strip() - initial_prompt_tokens = tokenizer.encode(initial_prompt) - all_tokens.extend(initial_prompt_tokens) - - previous_tokens = all_tokens[prompt_reset_since:] - prompt = model.get_prompt( - tokenizer, - previous_tokens, - without_timestamps=without_timestamps, - prefix=prefix, - ) - - features = self._encode_batch( - self.model, features, word_timestamps=word_timestamps - ) - - # TODO: We access the inherited ctranslate2 model for generation here. This is not ideal. 
- result: WhisperGenerationResult = model.model.generate( - features, - [prompt] * batch_size, - beam_size=beam_size, - patience=patience, - num_hypotheses=num_hypotheses, - length_penalty=length_penalty, - return_scores=True, - return_no_speech_prob=True, - suppress_blank=suppress_blank, - sampling_temperature=temperature, - sampling_topk=sampling_top_k, - ) - - outputs = [] - for res, time_offset, segment_duration, group_id in zip( - result, time_offsets, segment_durations, group_ids - ): - current_segments = [] - tokens = res.sequences_ids[0] - segment_score = res.scores[0] - _text = tokenizer.decode(tokens).strip() - - compression_ratio, average_log_probability = self._get_quality_metrics( - tokens, - _text, - segment_score, - length_penalty, - ) - - # We check if the segment is valid based on the metrics thresholds. - # Or if it is the last chance inference, we will accept the result even if not perfect. - if ( - average_log_probability > self.log_probability_threshold - and compression_ratio < self.compression_ratio_threshold - ) or last_chance_inference: - single_timestamp_ending = ( - len(tokens) >= 2 - and tokens[-2] < tokenizer.timestamp_begin - and tokens[-1] >= tokenizer.timestamp_begin - ) - - consecutive_timestamps = [ - i - for i in range(len(tokens)) - if i > 0 - and tokens[i] >= tokenizer.timestamp_begin - and tokens[i - 1] >= tokenizer.timestamp_begin - ] - - if len(consecutive_timestamps) > 0: - slices = list(consecutive_timestamps) - if single_timestamp_ending: - slices.append(len(tokens)) - - last_slice = 0 - for current_slice in slices: - sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_position = ( - sliced_tokens[0] - tokenizer.timestamp_begin - ) - end_timestamp_position = ( - sliced_tokens[-1] - tokenizer.timestamp_begin - ) - start_time = time_offset + start_timestamp_position * 0.02 - end_time = time_offset + end_timestamp_position * 0.02 - - current_segments.append( - dict( - start=start_time, - end=end_time, - tokens=sliced_tokens, - group_id=group_id, - ) - ) - last_slice = current_slice - else: - duration = segment_duration - timestamps = [ - token for token in tokens if token >= tokenizer.timestamp_begin - ] - if ( - len(timestamps) > 0 - and timestamps[-1] != tokenizer.timestamp_begin - ): - last_timestamp_position = ( - timestamps[-1] - tokenizer.timestamp_begin - ) - duration = last_timestamp_position * 0.02 - - current_segments.append( - dict( - start=time_offset, - end=time_offset + duration, - tokens=tokens, - group_id=group_id, - ) - ) - - outputs.append( - dict( - segments=self._decode_batch(current_segments, tokenizer), - need_fallback=len(current_segments) == 0, - ) - ) - - if word_timestamps: - segment_sizes = [ - int(segment_duration / (self.hop_length / self.sample_rate)) - for segment_duration in segment_durations - ] - self._add_word_timestamps( - outputs, - tokenizer, - features, - segment_sizes, - time_offsets, - self.prepend_punctuation, - self.append_punctuation, - ) - - return outputs - - def _encode_batch( - self, model: WhisperModel, features: torch.Tensor, word_timestamps: bool - ) -> StorageView: - """Encode the features using the model encoder. - - We encode the features only if word timestamps are enabled. - Otherwise, we just return the features formatted as a StorageView. - - Args: - model (WhisperModel): Model to use to encode the features. - features (torch.Tensor): Features to encode. - word_timestamps (bool): Whether to encode the features or not. - - Returns: - StorageView: Encoded features. 
- """ - features = get_ctranslate2_storage(features) - - if ( - word_timestamps - ): # We encode the features to re-use the encoder output later. - features = model.model.encode(features, to_cpu=False) - - return features - - def _decode_batch(self, outputs: List[dict], tokenizer: Tokenizer) -> List[dict]: - """ - Extract the token ids from the sequences ids and decode them using the tokenizer. - - Args: - outputs (List[dict]): List of outputs from the model. - tokenizer (Tokenizer): Tokenizer to use to decode the token ids. - - Returns: - List[str]: List of decoded texts. - """ - if len(outputs) == 0: - return outputs - - tokens_to_decode = [ - [token for token in out["tokens"] if token < tokenizer.eot] - for out in outputs - ] - # TODO: We call the inherited tokenizer here, because faster_whisper tokenizer - # doesn't have the decode_batch method. We should fix this in the future. - decoded_tokens = tokenizer.tokenizer.decode_batch(tokens_to_decode) - - for out, text in zip(outputs, decoded_tokens): - out["text"] = text - return outputs def dual_channel( @@ -946,356 +320,3 @@ def dual_channel( final_transcript.append(segment_dict) return final_transcript - - def _transcribe_dual_channel( - self, - model: WhisperModel, - tokenizer: Tokenizer, - audio: Union[str, torch.Tensor], - speaker_id: int, - vad_service: VadService, - ) -> List[dict]: - """ - Transcribe an audio file with two channels. - - Args: - model (WhisperModel): Model to use to transcribe the audio. - tokenizer (Tokenizer): Tokenizer to use to decode the token ids. - audio (Union[str, torch.Tensor]): Audio file path or loaded audio. - speaker_id (int): Speaker ID used in the diarization. - vad_service (VadService): VAD service. - - Returns: - List[dict]: List of transcribed segments. - """ - enhanced_audio = enhance_audio(audio, apply_agc=True, apply_bandpass=False) - grouped_segments, audio = vad_service(enhanced_audio) - - final_transcript = [] - silence_padding = torch.from_numpy(np.zeros(int(3 * self.sample_rate))).float() - - prepared_groups = [] - for group_id, group in enumerate(grouped_segments): - audio_segments = [] - for segment in group: - audio_segments.extend( - [audio[segment["start"] : segment["end"]], silence_padding] - ) - - prepared_groups.append( - DualChannelInput(group_id, torch.cat(audio_segments)) - ) - - segments = self.pipeline( - model, tokenizer, prepared_groups, self._batch_size, False, True - ) - - for segment in segments: - group_timestamps_base = ( - grouped_segments[segment["group_id"]][0]["start"] / self.sample_rate - ) - group_timestamps_shift = segment["group_id"] * 30 - - segment_dict = { - "start": None, - "end": None, - "text": segment["text"], - "words": [], - "speaker": speaker_id, - } - - for word in segment["words"]: - word_start_adjusted = ( - group_timestamps_base + word["start"] - group_timestamps_shift - ) - word_end_adjusted = ( - group_timestamps_base + word["end"] - group_timestamps_shift - ) - segment_dict["words"].append( - { - "start": word_start_adjusted, - "end": word_end_adjusted, - "word": word["word"], - } - ) - - if ( - segment_dict["start"] is None - or word_start_adjusted < segment_dict["start"] - ): - segment_dict["start"] = word_start_adjusted - - if ( - segment_dict["end"] is None - or word_end_adjusted > segment_dict["end"] - ): - segment_dict["end"] = word_end_adjusted - - final_transcript.append(segment_dict) - - return final_transcript - - @time_and_tell - def _add_word_timestamps( - self, - outputs: List[dict], - tokenizer: Tokenizer, - encoder_output: 
StorageView, - segment_sizes: List[int], - time_offsets: List[float], - prepend_punctuation: str, - append_punctuation: str, - ) -> None: - """ - Add word timestamps to the segments. - - Args: - outputs (List[dict]): List of outputs from the model. - tokenizer (Tokenizer): Tokenizer to use to decode the token ids. - encoder_output (StorageView): Encoder output. - segment_sizes (List[int]): List of segment sizes. - time_offsets (List[float]): List of time offsets. - prepend_punctuation (str): Punctuation to prepend to the text. - append_punctuation (str): Punctuation to append to the text. - """ - text_tokens_per_output = [] - for out in outputs: - text_tokens_per_segment = [ - [token for token in segment["tokens"] if token < tokenizer.eot] - for segment in out["segments"] - ] - text_tokens_per_output.append(text_tokens_per_segment) - - alignments = self._find_alignment( - encoder_output, text_tokens_per_output, tokenizer, segment_sizes - ) - self._merge_punctuation(alignments, prepend_punctuation, append_punctuation) - - for out, alignment, text_tokens_per_segment, time_offset in zip( - outputs, alignments, text_tokens_per_output, time_offsets - ): - if out["need_fallback"]: - continue - - word_index = 0 - - for segment_idx, segment in enumerate(out["segments"]): - saved_tokens = 0 - words = [] - - if isinstance(alignment, int): - alignment = [alignment] - - while word_index < len(alignment) and saved_tokens < len( - text_tokens_per_segment[segment_idx] - ): - timing = alignment[word_index] - - if timing["word"]: - words.append( - dict( - word=timing["word"], - start=round(time_offset + timing["start"], 2), - end=round(time_offset + timing["end"], 2), - probability=timing["probability"], - ) - ) - - saved_tokens += len(timing["tokens"]) - word_index += 1 - - if len(words) > 0: - segment["start"] = words[0]["start"] - segment["end"] = words[-1]["end"] - - segment["words"] = words - - def _find_alignment( - self, - encoder_output: StorageView, - text_tokens_per_output: List[List[int]], - tokenizer: Tokenizer, - segment_sizes: List[int], - median_filter_width: int = 7, - ) -> List[List[dict]]: - """ - Find the alignment between the encoder output and the text tokens in a batch. - - Args: - encoder_output (StorageView): Encoder output. - text_tokens_per_output (List[List[int]]): List of text tokens per output. - tokenizer (Tokenizer): Tokenizer to use to decode the token ids. - segment_sizes (List[int]): List of segment sizes. - median_filter_width (int): Width of the median filter to apply on the alignment. - - Returns: - List[List[dict]]: List of alignments per output. 
- """ - text_tokens_per_output = [ - list(itertools.chain.from_iterable(list_of_tokens)) - for list_of_tokens in text_tokens_per_output - ] - - results = self.model.model.align( - encoder_output, - tokenizer.sot_sequence, - text_tokens_per_output, - segment_sizes, - median_filter_width=median_filter_width, - ) - - final_alignments = [] - for res, text_tokens in zip(results, text_tokens_per_output): - words, word_tokens = tokenizer.split_to_word_tokens( - text_tokens + [tokenizer.eot] - ) - word_boundaries = np.pad( - np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0) - ) - if len(word_boundaries) <= 1: - final_alignments.append([]) - continue - - alignments = res.alignments - text_indices = np.array([pair[0] for pair in alignments]) - time_indices = np.array([pair[1] for pair in alignments]) - - jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype( - bool - ) - jump_times = time_indices[jumps] / self.tokens_per_second - start_times = jump_times[word_boundaries[:-1]] - end_times = jump_times[word_boundaries[1:]] - - text_token_probs = res.text_token_probs - word_probabilities = [ - np.mean(text_token_probs[i:j]) - for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) - ] - - word_durations = end_times - start_times - word_durations = word_durations[word_durations.nonzero()] - - if len(word_durations) > 0: - median_duration = np.median(word_durations) - max_duration = median_duration * 2 - - if len(word_durations) >= 2 and word_durations[1] > max_duration: - boundary = max(end_times[2] / 2, end_times[2] - max_duration) - end_times[0] = start_times[1] = boundary - - if ( - len(word_durations) >= 1 - and end_times[0] - start_times[0] > max_duration - ): - start_times[0] = max(0, end_times[0] - max_duration) - - final_alignments.append( - [ - dict( - word=word, - tokens=tokens, - start=start, - end=end, - probability=probability, - ) - for word, tokens, start, end, probability in zip( - words, word_tokens, start_times, end_times, word_probabilities - ) - ] - ) - - return final_alignments - - def _merge_punctuation( - self, alignments: List[List[dict]], prepended: str, appended: str - ) -> None: - """ - Fix punctuation boundaries for the alignments. - - Args: - alignments (List[List[dict]]): List of alignments. - prepended (str): Prepended punctuation. - appended (str): Appended punctuation. - """ - for alignment in alignments: - # merge prepended punctuations - i = len(alignment) - 2 - j = len(alignment) - 1 - while i >= 0: - previous = alignment[i] - following = alignment[j] - if ( - previous["word"].startswith(" ") - and previous["word"].strip() in prepended - ): - # prepend it to the following word - following["word"] = previous["word"] + following["word"] - following["tokens"] = previous["tokens"] + following["tokens"] - previous["word"] = "" - previous["tokens"] = [] - else: - j = i - i -= 1 - - # merge appended punctuations - i = 0 - j = 1 - while j < len(alignment): - previous = alignment[i] - following = alignment[j] - if not previous["word"].endswith(" ") and following["word"] in appended: - # append it to the previous word - previous["word"] = previous["word"] + following["word"] - previous["tokens"] = previous["tokens"] + following["tokens"] - following["word"] = "" - following["tokens"] = [] - else: - i = j - j += 1 - - def _get_quality_metrics( - self, tokens: List[int], text: str, score: float, length_penalty: float - ) -> Tuple[float, float]: - """ - Get the compression ratio and the average log probability of the outputs to score them. 
- - Args: - tokens (List[int]): List of token ids. - text (str): Decoded text. - score (float): Score of the sequence. - length_penalty (float): Length penalty to apply to the average log probability. - - Returns: - Tuple[float, float]: Compression ratio and average log probability. - """ - text_bytes = text.encode("utf-8") - compression_ratio = len(text_bytes) / len(zlib.compress(text_bytes)) - - seq_len = len(tokens) - cumulative_log_probability = score * (seq_len**length_penalty) - average_log_probability = cumulative_log_probability / (seq_len + 1) - - return compression_ratio, average_log_probability - - def _collate_fn( - self, items: List[Dict[str, Union[torch.Tensor, int, float, None]]] - ) -> Dict[str, Union[torch.Tensor, List[int], List[float], List[None]]]: - """ - Collator function for the dataloader. - - Args: - items (List[Dict[str, Union[int, torch.Tensor, List[float]]]]): List of items to collate. - - Returns: - Dict[str, Union[torch.Tensor, List[int], List[float], List[None]]]: Collated items. - """ - collated_items = { - "indexes": [item["index"] for item in items], - "features": torch.stack([item["feature"] for item in items]), - "time_offsets": [item["time_offset"] for item in items], - "segment_durations": [item["segment_duration"] for item in items], - "group_ids": [item["group_id"] for item in items], - } - - return collated_items diff --git a/wordcab_transcribe/utils.py b/wordcab_transcribe/utils.py index c0633d5..f053865 100644 --- a/wordcab_transcribe/utils.py +++ b/wordcab_transcribe/utils.py @@ -489,7 +489,7 @@ def format_punct(text: str): def format_segments( - segments: list, alignment: bool, use_batch: bool, word_timestamps: bool + segments: list, alignment: bool, word_timestamps: bool ) -> List[dict]: """ Format the segments to a list of dicts with start, end and text keys. Optionally include word timestamps. @@ -497,7 +497,6 @@ def format_segments( Args: segments (list): List of segments. alignment (bool): Whether the segments have been aligned. Used to format the word timestamps correctly. - use_batch (bool): Whether the segments are from a batch. Used to format the word timestamps correctly. word_timestamps (bool): Whether to include word timestamps. Returns: @@ -515,26 +514,15 @@ def format_segments( if alignment: segment_dict["words"] = segment["words"] else: - if use_batch: - _words = [ - { - "word": word["word"].strip(), - "start": word["start"], - "end": word["end"], - "score": round(word["probability"], 2), - } - for word in segment["words"] - ] - else: - _words = [ - { - "word": word.word.strip(), - "start": word.start, - "end": word.end, - "score": round(word.probability, 2), - } - for word in segment["words"] - ] + _words = [ + { + "word": word.word.strip(), + "start": word.start, + "end": word.end, + "score": round(word.probability, 2), + } + for word in segment["words"] + ] segment_dict["words"] = _words formatted_segments.append(segment_dict)
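Taken together, the transcribe-service and utils changes leave a single faster-whisper code path: the new quality options (`compression_ratio_threshold`, `log_prob_threshold`, `no_speech_threshold`, `condition_on_previous_text`) are forwarded straight to `WhisperModel.transcribe`, and `format_segments` now always reads word attributes (`word.word`, `word.start`, `word.end`, `word.probability`) rather than the old batch-mode dicts. A hedged end-to-end sketch — the model size, device, audio path and threshold values below are placeholders, not values fixed by this diff:

```python
# Illustrative sketch only: model size, device and "audio.wav" are placeholders,
# and the service normally wires these options through TranscribeService.__call__.
from faster_whisper import WhisperModel

model = WhisperModel("large-v2", device="cuda", compute_type="float16")

segments, _info = model.transcribe(
    "audio.wav",  # hypothetical input file
    language="en",
    repetition_penalty=1.2,
    compression_ratio_threshold=2.4,  # above this gzip ratio, treat the segment as failed
    log_prob_threshold=-1.0,  # below this average log prob, treat the segment as failed
    no_speech_threshold=0.6,  # combined with log_prob_threshold to drop silent windows
    condition_on_previous_text=True,
    suppress_blank=False,
    word_timestamps=True,
)

# With the batch path removed, each word is a faster-whisper Word namedtuple,
# so the formatting step reads attributes rather than dict keys:
formatted_segments = [
    {
        "start": segment.start,
        "end": segment.end,
        "text": segment.text,
        "words": [
            {
                "word": word.word.strip(),
                "start": word.start,
                "end": word.end,
                "score": round(word.probability, 2),
            }
            for word in segment.words
        ],
    }
    for segment in segments
]
```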