-
-
Notifications
You must be signed in to change notification settings - Fork 286
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Kye
committed
Nov 4, 2023
1 parent
75ebbe0
commit 7e1d486
Showing
2 changed files
with
121 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# test_distilled_whisperx.py | ||
|
||
from unittest.mock import AsyncMock, MagicMock | ||
|
||
import pytest | ||
import torch | ||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor | ||
|
||
from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry | ||
|
||
|
||
# Fixtures for setting up model, processor, and audio files | ||
@pytest.fixture(scope="module")
def model_id():
    """Provide the model identifier string shared by the tests in this module."""
    checkpoint = "distil-whisper/distil-large-v2"
    return checkpoint
|
||
|
||
@pytest.fixture(scope="module")
def whisper_model(model_id):
    """Build a ``DistilWhisperModel`` from the ``model_id`` fixture.

    The fixture is module-scoped, so the same instance is reused by every
    test in this module that requests it.
    """
    instance = DistilWhisperModel(model_id)
    return instance
|
||
|
||
@pytest.fixture(scope="session")
def audio_file_path(tmp_path_factory):
    """Create a short, valid audio file and return its path as a string.

    The previous implementation returned a placeholder path that does not
    exist on disk, so every "successful transcription" test was pointed at a
    missing file. This version writes one second of 16 kHz, mono, 16-bit
    silence into a session-scoped temporary directory.

    A WAV file is used instead of an MP3 because it can be produced with the
    standard library alone.
    NOTE(review): confirm that the model's audio loader accepts WAV input.

    Args:
        tmp_path_factory: pytest's session-scoped temporary directory factory.

    Returns:
        str: Path to the generated ``valid_audio.wav`` file.
    """
    # Imported locally so this fixture does not depend on the module's
    # top-level import block.
    import wave

    sample_rate = 16000
    wav_path = tmp_path_factory.mktemp("audio") / "valid_audio.wav"
    with wave.open(str(wav_path), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # two bytes per sample -> 16-bit PCM
        wav_file.setframerate(sample_rate)
        # One second of silence: sample_rate frames of two zero bytes each.
        wav_file.writeframes(b"\x00\x00" * sample_rate)
    return str(wav_path)
|
||
|
||
@pytest.fixture(scope="session")
def invalid_audio_file_path():
    """Provide the path string used by the failure-path tests."""
    bogus_path = "path/to/invalid_audio.mp3"
    return bogus_path
|
||
|
||
@pytest.fixture(scope="session")
def audio_dict():
    """Provide an in-memory audio payload with random samples.

    Returns:
        dict: ``"array"`` holds a ``torch.randn(1, 16000)`` tensor and
        ``"sampling_rate"`` holds ``16000``.
    """
    # NOTE(review): intended to match the dict format the model expects;
    # confirm that a (1, 16000) torch tensor is accepted for "array".
    sample_rate = 16000
    waveform = torch.randn(1, sample_rate)
    return {"array": waveform, "sampling_rate": sample_rate}
|
||
|
||
# Test initialization | ||
def test_initialization(whisper_model):
    """The wrapper must expose non-``None`` ``model`` and ``processor`` attributes."""
    for attribute in ("model", "processor"):
        assert getattr(whisper_model, attribute) is not None
|
||
|
||
# Test successful transcription with file path | ||
def test_transcribe_with_file_path(whisper_model, audio_file_path):
    """``transcribe`` called with a file path should return a string."""
    result = whisper_model.transcribe(audio_file_path)
    assert isinstance(result, str)
|
||
|
||
# Test successful transcription with audio dict | ||
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
    """``transcribe`` called with an audio dictionary should return a string."""
    result = whisper_model.transcribe(audio_dict)
    assert isinstance(result, str)
|
||
|
||
# Test for file not found error | ||
def test_file_not_found(whisper_model, invalid_audio_file_path):
    """``transcribe`` is expected to raise when given the invalid path."""
    bad_path = invalid_audio_file_path
    with pytest.raises(Exception):
        whisper_model.transcribe(bad_path)
|
||
|
||
# Asynchronous tests | ||
@pytest.mark.asyncio
async def test_async_transcription_success(whisper_model, audio_file_path):
    """Awaiting ``async_transcribe`` with a file path should yield a string."""
    result = await whisper_model.async_transcribe(audio_file_path)
    assert isinstance(result, str)
|
||
|
||
@pytest.mark.asyncio
async def test_async_transcription_failure(whisper_model, invalid_audio_file_path):
    """Awaiting ``async_transcribe`` with the invalid path is expected to raise."""
    bad_path = invalid_audio_file_path
    with pytest.raises(Exception):
        await whisper_model.async_transcribe(bad_path)
|
||
|
||
# Testing real-time transcription simulation | ||
def test_real_time_transcription(whisper_model, audio_file_path, capsys):
    """``real_time_transcribe`` should print its start banner to stdout."""
    whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)
    # readouterr() yields an (out, err) pair; only stdout matters here.
    stdout, _ = capsys.readouterr()
    assert "Starting real-time transcription..." in stdout
|
||
|
||
# Testing retry decorator for asynchronous function | ||
@pytest.mark.asyncio
async def test_async_retry():
    """A coroutine that always raises should still raise ``ValueError`` when wrapped.

    The coroutine is wrapped with ``async_retry`` configured for two retries
    on ``ValueError`` with zero delay.
    """

    async def failing_func():
        raise ValueError("Test")

    # Apply the decorator explicitly; equivalent to the ``@async_retry(...)`` syntax.
    retry_decorator = async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
    failing_func = retry_decorator(failing_func)

    with pytest.raises(ValueError):
        await failing_func()
|
||
|
||
# Mocking the actual model to avoid GPU/CPU intensive operations during test | ||
@pytest.fixture
def mocked_model(monkeypatch):
    """Patch the ``from_pretrained`` loaders so no real weights are loaded.

    Args:
        monkeypatch: pytest's monkeypatch fixture; the patches are undone
            automatically when the requesting test finishes.

    Returns:
        tuple: ``(model_mock, processor_mock)``. Each mock replaces the
        corresponding ``from_pretrained`` callable, so its ``return_value``
        is the object the code under test receives as the loaded
        model/processor.
    """
    # from_pretrained is a regular (non-async) callable, so its stand-in must
    # be a MagicMock. Calling an AsyncMock returns a coroutine object rather
    # than the configured return_value, which would hand the wrapper an
    # un-awaited coroutine instead of a model.
    model_mock = MagicMock(AutoModelForSpeechSeq2Seq)
    processor_mock = MagicMock(AutoProcessor)
    # NOTE(review): presumably the wrapper moves the model with ``.to(device)``;
    # returning the same mock from ``.to`` keeps whatever the test configures
    # on ``model_mock.return_value`` visible either way. Confirm against the
    # wrapper's __init__.
    model_mock.return_value.to.return_value = model_mock.return_value
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
        model_mock,
    )
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained",
        processor_mock,
    )
    return model_mock, processor_mock
|
||
|
||
@pytest.mark.asyncio
async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path):
    """With the loaders mocked, ``async_transcribe`` should return the canned text."""
    model_mock, processor_mock = mocked_model

    # Configure the objects the patched loaders hand back before the wrapper
    # is constructed.
    fake_model = model_mock.return_value
    fake_processor = processor_mock.return_value
    fake_model.generate.return_value = torch.tensor([[0]])
    fake_processor.batch_decode.return_value = ["mocked transcription"]

    wrapper = DistilWhisperModel()
    text = await wrapper.async_transcribe(audio_file_path)
    assert text == "mocked transcription"
|