diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1750e2a336..7b5a6fd1a8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -58,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for the `ObjectDetector` with FiftyOne ([#727](https://github.com/PyTorchLightning/lightning-flash/pull/727))
 
+- Added support for MP3 files to the `SpeechRecognition` task with librosa ([#726](https://github.com/PyTorchLightning/lightning-flash/pull/726))
+
 ### Changed
 
 - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py
index 029419b50b..baa3a32cc7 100644
--- a/flash/audio/speech_recognition/data.py
+++ b/flash/audio/speech_recognition/data.py
@@ -35,7 +35,7 @@
 from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires_extras
 
 if _AUDIO_AVAILABLE:
-    import soundfile as sf
+    import librosa
     from datasets import Dataset as HFDataset
     from datasets import load_dataset
     from transformers import Wav2Vec2CTCTokenizer
@@ -44,11 +44,16 @@
 class SpeechRecognitionDeserializer(Deserializer):
+    def __init__(self, sampling_rate: int):
+        super().__init__()
+
+        self.sampling_rate = sampling_rate
+
     def deserialize(self, sample: Any) -> Dict:
         encoded_with_padding = (sample + "===").encode("ascii")
         audio = base64.b64decode(encoded_with_padding)
         buffer = io.BytesIO(audio)
-        data, sampling_rate = sf.read(buffer)
+        data, sampling_rate = librosa.load(buffer, sr=self.sampling_rate)
         return {
             DefaultDataKeys.INPUT: data,
             DefaultDataKeys.METADATA: {"sampling_rate": sampling_rate},
         }
 
@@ -61,7 +66,8 @@ def example_input(self) -> str:
 
 
 class BaseSpeechRecognition:
-    def _load_sample(self, sample: Dict[str, Any]) -> Any:
+    @staticmethod
+    def _load_sample(sample: Dict[str, Any], sampling_rate: int) -> Any:
         path = sample[DefaultDataKeys.INPUT]
         if (
             not os.path.isabs(path)
@@ -69,16 +75,17 @@ def _load_sample(self, sample: Dict[str, Any]) -> Any:
             and "root" in sample[DefaultDataKeys.METADATA]
         ):
             path = os.path.join(sample[DefaultDataKeys.METADATA]["root"], path)
-        speech_array, sampling_rate = sf.read(path)
+        speech_array, sampling_rate = librosa.load(path, sr=sampling_rate)
         sample[DefaultDataKeys.INPUT] = speech_array
         sample[DefaultDataKeys.METADATA] = {"sampling_rate": sampling_rate}
         return sample
 
 
 class SpeechRecognitionFileDataSource(DataSource, BaseSpeechRecognition):
-    def __init__(self, filetype: Optional[str] = None):
+    def __init__(self, sampling_rate: int, filetype: Optional[str] = None):
         super().__init__()
         self.filetype = filetype
+        self.sampling_rate = sampling_rate
 
     def load_data(
         self,
@@ -107,32 +114,44 @@ def load_data(
         ]
 
     def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
-        return self._load_sample(sample)
+        return self._load_sample(sample, self.sampling_rate)
 
 
 class SpeechRecognitionCSVDataSource(SpeechRecognitionFileDataSource):
-    def __init__(self):
-        super().__init__(filetype="csv")
+    def __init__(self, sampling_rate: int):
+        super().__init__(sampling_rate, filetype="csv")
 
 
 class SpeechRecognitionJSONDataSource(SpeechRecognitionFileDataSource):
-    def __init__(self):
-        super().__init__(filetype="json")
+    def __init__(self, sampling_rate: int):
+        super().__init__(sampling_rate, filetype="json")
 
 
 class SpeechRecognitionDatasetDataSource(DatasetDataSource, BaseSpeechRecognition):
+    def __init__(self, sampling_rate: int):
+        super().__init__()
+
+        self.sampling_rate = sampling_rate
+
     def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]]]:
         if isinstance(data, HFDataset):
             data = list(zip(data["file"], data["text"]))
         return super().load_data(data, dataset)
 
+    def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
+        if isinstance(sample[DefaultDataKeys.INPUT], (str, Path)):
+            sample = self._load_sample(sample, self.sampling_rate)
+        return sample
+
 
 class SpeechRecognitionPathsDataSource(PathsDataSource, BaseSpeechRecognition):
-    def __init__(self):
-        super().__init__(("wav", "ogg", "flac", "mat"))
+    def __init__(self, sampling_rate: int):
+        super().__init__(("wav", "ogg", "flac", "mat", "mp3"))
+
+        self.sampling_rate = sampling_rate
 
     def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
-        return self._load_sample(sample)
+        return self._load_sample(sample, self.sampling_rate)
 
 
 class SpeechRecognitionPreprocess(Preprocess):
@@ -143,6 +162,7 @@ def __init__(
         self,
         val_transform: Optional[Dict[str, Callable]] = None,
         test_transform: Optional[Dict[str, Callable]] = None,
         predict_transform: Optional[Dict[str, Callable]] = None,
+        sampling_rate: int = 16000,
     ):
         super().__init__(
             train_transform=train_transform,
@@ -150,13 +170,13 @@ def __init__(
             test_transform=test_transform,
             predict_transform=predict_transform,
             data_sources={
-                DefaultDataSources.CSV: SpeechRecognitionCSVDataSource(),
-                DefaultDataSources.JSON: SpeechRecognitionJSONDataSource(),
-                DefaultDataSources.FILES: SpeechRecognitionPathsDataSource(),
-                DefaultDataSources.DATASETS: SpeechRecognitionDatasetDataSource(),
+                DefaultDataSources.CSV: SpeechRecognitionCSVDataSource(sampling_rate),
+                DefaultDataSources.JSON: SpeechRecognitionJSONDataSource(sampling_rate),
+                DefaultDataSources.FILES: SpeechRecognitionPathsDataSource(sampling_rate),
+                DefaultDataSources.DATASETS: SpeechRecognitionDatasetDataSource(sampling_rate),
             },
             default_data_source=DefaultDataSources.FILES,
-            deserializer=SpeechRecognitionDeserializer(),
+            deserializer=SpeechRecognitionDeserializer(sampling_rate),
         )
 
     def get_state_dict(self) -> Dict[str, Any]:
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index fb866a5f84..592a3c7b52 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -88,7 +88,7 @@ def _compare_version(package: str, op, version) -> bool:
 _PIL_AVAILABLE = _module_available("PIL")
 _OPEN3D_AVAILABLE = _module_available("open3d")
 _SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
-_SOUNDFILE_AVAILABLE = _module_available("soundfile")
+_LIBROSA_AVAILABLE = _module_available("librosa")
 _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
 _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
 _TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
@@ -148,7 +148,7 @@ class Image(metaclass=MetaImage):
 )
 _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
 _POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE
-_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _SOUNDFILE_AVAILABLE, _TRANSFORMERS_AVAILABLE])
+_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _LIBROSA_AVAILABLE, _TRANSFORMERS_AVAILABLE])
 _GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE
 
 _EXTRAS_AVAILABLE = {
diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt
index 4c198da250..2353d8d18f 100644
--- a/requirements/datatype_audio.txt
+++ b/requirements/datatype_audio.txt
@@ -1,4 +1,4 @@
 torchaudio
-soundfile>=0.10.2
+librosa>=0.8.1
 transformers>=4.5
 datasets>=1.8
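
Reviewer note: the functional reason for the swap is that `librosa.load` both decodes MP3 (falling back to its `audioread` backend where libsndfile could not handle the codec at the time) and resamples to whatever rate is passed as `sr=...`, whereas `soundfile.read` returned samples at the file's native rate and had no MP3 support. Below is a minimal usage sketch of the new `sampling_rate` plumbing, not part of the patch: the file names are placeholders, and it assumes the Flash 0.5-era `from_files` API, which forwards extra keyword arguments on to `SpeechRecognitionPreprocess`.

```python
import librosa

from flash.audio import SpeechRecognitionData

# librosa resamples on load: `rate` is always the requested 16 kHz here,
# even if "speech.mp3" (a placeholder path) was recorded at 44.1 kHz.
waveform, rate = librosa.load("speech.mp3", sr=16000)
assert rate == 16000

# The same rate threads through the task's data pipeline: `sampling_rate`
# reaches SpeechRecognitionPreprocess (default 16000) and, per the diff
# above, every data source and the deserializer it constructs.
datamodule = SpeechRecognitionData.from_files(
    predict_files=["speech.mp3", "speech.wav"],  # "mp3" is now an accepted extension
    sampling_rate=16000,
)
```

The 16000 default matches the 16 kHz audio that Wav2Vec2-style backbones were pretrained on, so resampling at load time keeps arbitrary-rate inputs (including MP3) compatible with the model.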