From be7993998f76948e3ea5a3b3fb602bc16f953cae Mon Sep 17 00:00:00 2001
From: michaeldecent2 <111002205+MichaelDecent@users.noreply.github.com>
Date: Mon, 4 Nov 2024 21:19:17 +0100
Subject: [PATCH] swarm - feat: Implement Whisper Large V3 model

---
 .../llms/concrete/WhisperLargeModel.py        | 144 ++++++++++++++++++
 .../unit/llms/WhisperLargeModel_unit_test.py  | 141 +++++++++++++++++
 2 files changed, 285 insertions(+)
 create mode 100644 pkgs/swarmauri/swarmauri/llms/concrete/WhisperLargeModel.py
 create mode 100644 pkgs/swarmauri/tests/unit/llms/WhisperLargeModel_unit_test.py

diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/WhisperLargeModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/WhisperLargeModel.py
new file mode 100644
index 000000000..812e4783f
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/llms/concrete/WhisperLargeModel.py
@@ -0,0 +1,144 @@
+from typing import List, Literal, Dict
+import requests
+import asyncio
+import aiohttp
+from swarmauri.llms.base.LLMBase import LLMBase
+
+
+class WhisperLargeModel(LLMBase):
+    """
+    Implementation of the Whisper Large V3 model using Hugging Face's Inference API.
+    https://huggingface.co/openai/whisper-large-v3
+    """
+
+    allowed_models: List[str] = ["openai/whisper-large-v3"]
+    name: str = "openai/whisper-large-v3"
+    type: Literal["WhisperLargeModel"] = "WhisperLargeModel"
+    api_key: str
+
+    __API_URL: str = "https://api-inference.huggingface.co/models/openai/whisper-large-v3"
+
+    def predict(
+        self,
+        audio_path: str,
+        task: Literal["transcription", "translation"] = "transcription",
+    ) -> str:
+        """
+        Process a single audio file using the Hugging Face Inference API.
+
+        Args:
+            audio_path (str): Path to the audio file
+            task (Literal["transcription", "translation"]): Task to perform
+
+        Returns:
+            str: Transcribed or translated text
+        """
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+
+        if task not in ["transcription", "translation"]:
+            raise ValueError(
+                f"Task {task} not supported. Choose from ['transcription', 'translation']"
+            )
+
+        with open(audio_path, "rb") as audio_file:
+            data = audio_file.read()
+
+        params = {"task": task}
+        if task == "translation":
+            params["language"] = "en"
+
+        response = requests.post(
+            self.__API_URL, headers=headers, data=data, params=params
+        )
+
+        if response.status_code != 200:
+            raise Exception(
+                f"API request failed with status code {response.status_code}: {response.text}"
+            )
+
+        result = response.json()
+
+        if isinstance(result, dict):
+            return result.get("text", "")
+        elif isinstance(result, list) and len(result) > 0:
+            return result[0].get("text", "")
+        else:
+            raise Exception("Unexpected API response format")
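+
+    # Minimal usage sketch (illustrative comment, not part of the class API;
+    # assumes a valid Hugging Face token and a local audio file):
+    #
+    #     model = WhisperLargeModel(api_key="hf_...")
+    #     text = model.predict(audio_path="sample.mp3")
+    #     english = model.predict(audio_path="sample_fr.mp3", task="translation")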
+
+    async def apredict(
+        self,
+        audio_path: str,
+        task: Literal["transcription", "translation"] = "transcription",
+    ) -> str:
+        """
+        Asynchronously process a single audio file using the Hugging Face
+        Inference API.
+
+        Args:
+            audio_path (str): Path to the audio file
+            task (Literal["transcription", "translation"]): Task to perform
+
+        Returns:
+            str: Transcribed or translated text
+        """
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+
+        if task not in ["transcription", "translation"]:
+            raise ValueError(
+                f"Task {task} not supported. Choose from ['transcription', 'translation']"
+            )
+
+        with open(audio_path, "rb") as audio_file:
+            data = audio_file.read()
+
+        params = {"task": task}
+        if task == "translation":
+            params["language"] = "en"
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                self.__API_URL, headers=headers, data=data, params=params
+            ) as response:
+                if response.status != 200:
+                    raise Exception(
+                        f"API request failed with status code {response.status}: {await response.text()}"
+                    )
+
+                result = await response.json()
+
+        if isinstance(result, dict):
+            return result.get("text", "")
+        elif isinstance(result, list) and len(result) > 0:
+            return result[0].get("text", "")
+        else:
+            raise Exception("Unexpected API response format")
+
+    def batch(
+        self,
+        path_task_dict: Dict[str, Literal["transcription", "translation"]],
+    ) -> List[str]:
+        """
+        Synchronously process multiple audio files.
+
+        Args:
+            path_task_dict (Dict[str, Literal["transcription", "translation"]]):
+                Dictionary mapping file paths to tasks
+
+        Returns:
+            List[str]: List of transcribed/translated texts
+        """
+        return [
+            self.predict(audio_path=path, task=task)
+            for path, task in path_task_dict.items()
+        ]
+
+    async def abatch(
+        self,
+        path_task_dict: Dict[str, Literal["transcription", "translation"]],
+        max_concurrent: int = 5,
+    ) -> List[str]:
+        """
+        Process multiple audio files in parallel with controlled concurrency.
+
+        Args:
+            path_task_dict (Dict[str, Literal["transcription", "translation"]]):
+                Dictionary mapping file paths to tasks
+            max_concurrent (int): Maximum number of concurrent requests
+
+        Returns:
+            List[str]: List of transcribed/translated texts
+        """
+        semaphore = asyncio.Semaphore(max_concurrent)
+
+        async def process_audio(path: str, task: str) -> str:
+            async with semaphore:
+                return await self.apredict(audio_path=path, task=task)
+
+        tasks = [process_audio(path, task) for path, task in path_task_dict.items()]
+
+        return await asyncio.gather(*tasks)
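+
+    # Async usage sketch (illustrative comment; file names are placeholders):
+    #
+    #     results = asyncio.run(
+    #         model.abatch({"a.mp3": "transcription", "b_fr.mp3": "translation"})
+    #     )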
diff --git a/pkgs/swarmauri/tests/unit/llms/WhisperLargeModel_unit_test.py b/pkgs/swarmauri/tests/unit/llms/WhisperLargeModel_unit_test.py
new file mode 100644
index 000000000..882b5cb21
--- /dev/null
+++ b/pkgs/swarmauri/tests/unit/llms/WhisperLargeModel_unit_test.py
@@ -0,0 +1,141 @@
+import logging
+import pytest
+import os
+from swarmauri.llms.concrete.WhisperLargeModel import WhisperLargeModel as LLM
+from swarmauri.utils.timeout_wrapper import timeout
+from pathlib import Path
+from dotenv import load_dotenv
+
+load_dotenv()
+
+API_KEY = os.getenv("HUGGINGFACE_TOKEN")
+
+# Resolve the tests root directory (two levels above this file)
+root_dir = Path(__file__).resolve().parents[2]
+
+# Construct fixture file paths dynamically
+file_path = os.path.join(root_dir, "static", "test.mp3")
+file_path2 = os.path.join(root_dir, "static", "test_fr.mp3")
+
+
+@pytest.fixture(scope="module")
+def whisperlarge_model():
+    if not API_KEY:
+        pytest.skip("Skipping due to environment variable not set")
+    llm = LLM(api_key=API_KEY)
+    return llm
+
+
+def get_allowed_models():
+    if not API_KEY:
+        return []
+    llm = LLM(api_key=API_KEY)
+    return llm.allowed_models
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_ubc_resource(whisperlarge_model):
+    assert whisperlarge_model.resource == "LLM"
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_ubc_type(whisperlarge_model):
+    assert whisperlarge_model.type == "WhisperLargeModel"
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_serialization(whisperlarge_model):
+    assert (
+        whisperlarge_model.id
+        == LLM.model_validate_json(whisperlarge_model.model_dump_json()).id
+    )
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_default_name(whisperlarge_model):
+    assert whisperlarge_model.name == "openai/whisper-large-v3"
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_audio_transcription(whisperlarge_model, model_name):
+    model = whisperlarge_model
+    model.name = model_name
+
+    prediction = model.predict(audio_path=file_path)
+
+    logging.info(prediction)
+
+    assert isinstance(prediction, str)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_audio_translation(whisperlarge_model, model_name):
+    model = whisperlarge_model
+    model.name = model_name
+
+    prediction = model.predict(
+        audio_path=file_path,
+        task="translation",
+    )
+
+    logging.info(prediction)
+
+    assert isinstance(prediction, str)
+
+
+@timeout(5)
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+async def test_apredict(whisperlarge_model, model_name):
+    whisperlarge_model.name = model_name
+
+    prediction = await whisperlarge_model.apredict(
+        audio_path=file_path,
+        task="translation",
+    )
+
+    logging.info(prediction)
+    assert isinstance(prediction, str)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_batch(whisperlarge_model, model_name):
+    model = whisperlarge_model
+    model.name = model_name
+
+    path_task_dict = {
+        file_path: "translation",
+        file_path2: "transcription",
+    }
+
+    results = model.batch(path_task_dict=path_task_dict)
+    assert len(results) == len(path_task_dict)
+    for result in results:
+        assert isinstance(result, str)
+
+
+@timeout(5)
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+async def test_abatch(whisperlarge_model, model_name):
+    model = whisperlarge_model
+    model.name = model_name
+
+    path_task_dict = {
+        file_path: "translation",
+        file_path2: "transcription",
+    }
+
+    results = await model.abatch(path_task_dict=path_task_dict)
+    assert len(results) == len(path_task_dict)
+    for result in results:
+        assert isinstance(result, str)
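+
+
+# Local run sketch (illustrative comment; assumes HUGGINGFACE_TOKEN is set,
+# e.g. via .env, and that the static/test.mp3 and static/test_fr.mp3 fixtures
+# exist):
+#
+#     pytest pkgs/swarmauri/tests/unit/llms/WhisperLargeModel_unit_test.py -m unit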