swarm - WhisperLargeModel implemented #754

Merged (1 commit) on Nov 9, 2024

144 changes: 144 additions & 0 deletions pkgs/swarmauri/swarmauri/llms/concrete/WhisperLargeModel.py
@@ -0,0 +1,144 @@
from typing import List, Literal, Dict
import requests
import asyncio
import aiohttp
from swarmauri.llms.base.LLMBase import LLMBase


class WhisperLargeModel(LLMBase):
"""
Implementation of Whisper Large V3 model using HuggingFace's Inference API
https://huggingface.co/openai/whisper-large-v3
"""

allowed_models: List[str] = ["openai/whisper-large-v3"]
name: str = "openai/whisper-large-v3"
type: Literal["WhisperLargeModel"] = "WhisperLargeModel"
api_key: str

__API_URL: str = "https://api-inference.huggingface.co/models/openai/whisper-large-v3"

def predict(
self,
audio_path: str,
task: Literal["transcription", "translation"] = "transcription",
) -> str:
"""
Process a single audio file using the Hugging Face Inference API

Args:
audio_path (str): Path to the audio file
task (Literal["transcription", "translation"]): Task to perform

Returns:
str: Transcribed or translated text
"""
headers = {"Authorization": f"Bearer {self.api_key}"}

if task not in ["transcription", "translation"]:
raise ValueError(
f"Task {task} not supported. Choose from ['transcription', 'translation']"
)

with open(audio_path, "rb") as audio_file:
data = audio_file.read()

params = {"task": task}
if task == "translation":
params["language"] = "en"

response = requests.post(
self.__API_URL, headers=headers, data=data, params=params
)

if response.status_code != 200:
raise Exception(
f"API request failed with status code {response.status_code}: {response.text}"
)

result = response.json()

if isinstance(result, dict):
return result.get("text", "")
elif isinstance(result, list) and len(result) > 0:
return result[0].get("text", "")
else:
raise Exception("Unexpected API response format")

async def apredict(
self,
audio_path: str,
task: Literal["transcription", "translation"] = "transcription",
) -> str:
"""
Asynchronously process a single audio file
"""
headers = {"Authorization": f"Bearer {self.api_key}"}

if task not in ["transcription", "translation"]:
raise ValueError(
f"Task {task} not supported. Choose from ['transcription', 'translation']"
)

with open(audio_path, "rb") as audio_file:
data = audio_file.read()

params = {"task": task}
if task == "translation":
params["language"] = "en"

async with aiohttp.ClientSession() as session:
async with session.post(
self.__API_URL, headers=headers, data=data, params=params
) as response:
if response.status != 200:
raise Exception(
f"API request failed with status code {response.status}: {await response.text()}"
)

result = await response.json()

if isinstance(result, dict):
return result.get("text", "")
elif isinstance(result, list) and len(result) > 0:
return result[0].get("text", "")
else:
raise Exception("Unexpected API response format")

def batch(
self,
path_task_dict: Dict[str, Literal["transcription", "translation"]],
) -> List[str]:
"""
Synchronously process multiple audio files
"""
return [
self.predict(audio_path=path, task=task)
for path, task in path_task_dict.items()
]

async def abatch(
self,
path_task_dict: Dict[str, Literal["transcription", "translation"]],
max_concurrent: int = 5,
) -> List[str]:
"""
Process multiple audio files in parallel with controlled concurrency

Args:
path_task_dict (Dict[str, Literal["transcription", "translation"]]):
Dictionary mapping file paths to tasks
max_concurrent (int): Maximum number of concurrent requests

Returns:
List[str]: List of transcribed/translated texts
"""
semaphore = asyncio.Semaphore(max_concurrent)

async def process_audio(path: str, task: str) -> str:
async with semaphore:
return await self.apredict(audio_path=path, task=task)

tasks = [process_audio(path, task) for path, task in path_task_dict.items()]

return await asyncio.gather(*tasks)
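
A minimal usage sketch for the synchronous path (the audio file names and the HUGGINGFACE_TOKEN environment variable are illustrative assumptions, not part of this diff):

import os

from swarmauri.llms.concrete.WhisperLargeModel import WhisperLargeModel

llm = WhisperLargeModel(api_key=os.environ["HUGGINGFACE_TOKEN"])

# Transcribe one clip, then translate a French clip into English.
text = llm.predict(audio_path="test.mp3", task="transcription")
english = llm.predict(audio_path="test_fr.mp3", task="translation")

# Batch over a path -> task mapping; results follow the dict's iteration order.
results = llm.batch(
    path_task_dict={"test.mp3": "transcription", "test_fr.mp3": "translation"}
)
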
141 changes: 141 additions & 0 deletions pkgs/swarmauri/tests/unit/llms/WhisperLargeModel_unit_test.py
@@ -0,0 +1,141 @@
import logging
import pytest
import os
from swarmauri.llms.concrete.WhisperLargeModel import WhisperLargeModel as LLM
from swarmauri.utils.timeout_wrapper import timeout
from pathlib import Path
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("HUGGINGFACE_TOKEN")

# Resolve the tests root directory (two levels above this file)
root_dir = Path(__file__).resolve().parents[2]

# Construct file paths dynamically
file_path = os.path.join(root_dir, "static", "test.mp3")
file_path2 = os.path.join(root_dir, "static", "test_fr.mp3")


@pytest.fixture(scope="module")
def whisperlarge_model():
    if not API_KEY:
        pytest.skip("Skipping due to environment variable not set")
    llm = LLM(api_key=API_KEY)
    return llm


def get_allowed_models():
    if not API_KEY:
        return []
    llm = LLM(api_key=API_KEY)
    return llm.allowed_models


@timeout(5)
@pytest.mark.unit
def test_ubc_resource(whisperlarge_model):
    assert whisperlarge_model.resource == "LLM"


@timeout(5)
@pytest.mark.unit
def test_ubc_type(whisperlarge_model):
    assert whisperlarge_model.type == "WhisperLargeModel"


@timeout(5)
@pytest.mark.unit
def test_serialization(whisperlarge_model):
    assert (
        whisperlarge_model.id
        == LLM.model_validate_json(whisperlarge_model.model_dump_json()).id
    )


@timeout(5)
@pytest.mark.unit
def test_default_name(whisperlarge_model):
    assert whisperlarge_model.name == "openai/whisper-large-v3"


@timeout(5)
@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_audio_transcription(whisperlarge_model, model_name):
    model = whisperlarge_model
    model.name = model_name

    prediction = model.predict(audio_path=file_path)

    logging.info(prediction)

    assert type(prediction) is str


@timeout(5)
@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_audio_translation(whisperlarge_model, model_name):
    model = whisperlarge_model
    model.name = model_name

    prediction = model.predict(
        audio_path=file_path,
        task="translation",
    )

    logging.info(prediction)

    assert type(prediction) is str


@timeout(5)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
async def test_apredict(whisperlarge_model, model_name):
    whisperlarge_model.name = model_name

    prediction = await whisperlarge_model.apredict(
        audio_path=file_path,
        task="translation",
    )

    logging.info(prediction)
    assert type(prediction) is str


@timeout(5)
@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_batch(whisperlarge_model, model_name):
    model = whisperlarge_model
    model.name = model_name

    path_task_dict = {
        file_path: "translation",
        file_path2: "transcription",
    }

    results = model.batch(path_task_dict=path_task_dict)
    assert len(results) == len(path_task_dict)
    for result in results:
        assert isinstance(result, str)


@timeout(5)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
async def test_abatch(whisperlarge_model, model_name):
    model = whisperlarge_model
    model.name = model_name

    path_task_dict = {
        file_path: "translation",
        file_path2: "transcription",
    }

    results = await model.abatch(path_task_dict=path_task_dict)
    assert len(results) == len(path_task_dict)
    for result in results:
        assert isinstance(result, str)
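
A matching sketch for the async path, showing the bounded concurrency of abatch (again assuming a valid HUGGINGFACE_TOKEN and local audio files; not part of this diff):

import asyncio
import os

from swarmauri.llms.concrete.WhisperLargeModel import WhisperLargeModel


async def main() -> None:
    llm = WhisperLargeModel(api_key=os.environ["HUGGINGFACE_TOKEN"])
    # At most two requests are in flight at any time.
    texts = await llm.abatch(
        path_task_dict={"test.mp3": "transcription", "test_fr.mp3": "translation"},
        max_concurrent=2,
    )
    print(texts)


asyncio.run(main())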