Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
replace soundfile with librosa (#726)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
  • Loading branch information
4 people authored Sep 6, 2021
1 parent cf86275 commit 4dd2830
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 21 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for the `ObjectDetector` with FiftyOne ([#727](https://github.com/PyTorchLightning/lightning-flash/pull/727))

- Added support for MP3 files to the `SpeechRecognition` task with librosa ([#726](https://github.com/PyTorchLightning/lightning-flash/pull/726))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
56 changes: 38 additions & 18 deletions flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires_extras

if _AUDIO_AVAILABLE:
import soundfile as sf
import librosa
from datasets import Dataset as HFDataset
from datasets import load_dataset
from transformers import Wav2Vec2CTCTokenizer
Expand All @@ -44,11 +44,16 @@


class SpeechRecognitionDeserializer(Deserializer):
def __init__(self, sampling_rate: int):
super().__init__()

self.sampling_rate = sampling_rate

def deserialize(self, sample: Any) -> Dict:
encoded_with_padding = (sample + "===").encode("ascii")
audio = base64.b64decode(encoded_with_padding)
buffer = io.BytesIO(audio)
data, sampling_rate = sf.read(buffer)
data, sampling_rate = librosa.load(buffer, sr=self.sampling_rate)
return {
DefaultDataKeys.INPUT: data,
DefaultDataKeys.METADATA: {"sampling_rate": sampling_rate},
Expand All @@ -61,24 +66,26 @@ def example_input(self) -> str:


class BaseSpeechRecognition:
def _load_sample(self, sample: Dict[str, Any]) -> Any:
@staticmethod
def _load_sample(sample: Dict[str, Any], sampling_rate: int) -> Any:
path = sample[DefaultDataKeys.INPUT]
if (
not os.path.isabs(path)
and DefaultDataKeys.METADATA in sample
and "root" in sample[DefaultDataKeys.METADATA]
):
path = os.path.join(sample[DefaultDataKeys.METADATA]["root"], path)
speech_array, sampling_rate = sf.read(path)
speech_array, sampling_rate = librosa.load(path, sr=sampling_rate)
sample[DefaultDataKeys.INPUT] = speech_array
sample[DefaultDataKeys.METADATA] = {"sampling_rate": sampling_rate}
return sample


class SpeechRecognitionFileDataSource(DataSource, BaseSpeechRecognition):
def __init__(self, filetype: Optional[str] = None):
def __init__(self, sampling_rate: int, filetype: Optional[str] = None):
super().__init__()
self.filetype = filetype
self.sampling_rate = sampling_rate

def load_data(
self,
Expand Down Expand Up @@ -107,32 +114,44 @@ def load_data(
]

def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
return self._load_sample(sample)
return self._load_sample(sample, self.sampling_rate)


class SpeechRecognitionCSVDataSource(SpeechRecognitionFileDataSource):
def __init__(self):
super().__init__(filetype="csv")
def __init__(self, sampling_rate: int):
super().__init__(sampling_rate, filetype="csv")


class SpeechRecognitionJSONDataSource(SpeechRecognitionFileDataSource):
def __init__(self):
super().__init__(filetype="json")
def __init__(self, sampling_rate: int):
super().__init__(sampling_rate, filetype="json")


class SpeechRecognitionDatasetDataSource(DatasetDataSource, BaseSpeechRecognition):
def __init__(self, sampling_rate: int):
super().__init__()

self.sampling_rate = sampling_rate

def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]]]:
if isinstance(data, HFDataset):
data = list(zip(data["file"], data["text"]))
return super().load_data(data, dataset)

def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
if isinstance(sample[DefaultDataKeys.INPUT], (str, Path)):
sample = self._load_sample(sample, self.sampling_rate)
return sample


class SpeechRecognitionPathsDataSource(PathsDataSource, BaseSpeechRecognition):
def __init__(self):
super().__init__(("wav", "ogg", "flac", "mat"))
def __init__(self, sampling_rate: int):
super().__init__(("wav", "ogg", "flac", "mat", "mp3"))

self.sampling_rate = sampling_rate

def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
return self._load_sample(sample)
return self._load_sample(sample, self.sampling_rate)


class SpeechRecognitionPreprocess(Preprocess):
Expand All @@ -143,20 +162,21 @@ def __init__(
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
sampling_rate: int = 16000,
):
super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.CSV: SpeechRecognitionCSVDataSource(),
DefaultDataSources.JSON: SpeechRecognitionJSONDataSource(),
DefaultDataSources.FILES: SpeechRecognitionPathsDataSource(),
DefaultDataSources.DATASETS: SpeechRecognitionDatasetDataSource(),
DefaultDataSources.CSV: SpeechRecognitionCSVDataSource(sampling_rate),
DefaultDataSources.JSON: SpeechRecognitionJSONDataSource(sampling_rate),
DefaultDataSources.FILES: SpeechRecognitionPathsDataSource(sampling_rate),
DefaultDataSources.DATASETS: SpeechRecognitionDatasetDataSource(sampling_rate),
},
default_data_source=DefaultDataSources.FILES,
deserializer=SpeechRecognitionDeserializer(),
deserializer=SpeechRecognitionDeserializer(sampling_rate),
)

def get_state_dict(self) -> Dict[str, Any]:
Expand Down
4 changes: 2 additions & 2 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _compare_version(package: str, op, version) -> bool:
_PIL_AVAILABLE = _module_available("PIL")
_OPEN3D_AVAILABLE = _module_available("open3d")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_SOUNDFILE_AVAILABLE = _module_available("soundfile")
_LIBROSA_AVAILABLE = _module_available("librosa")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
Expand Down Expand Up @@ -148,7 +148,7 @@ class Image(metaclass=MetaImage):
)
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE
_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _SOUNDFILE_AVAILABLE, _TRANSFORMERS_AVAILABLE])
_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _LIBROSA_AVAILABLE, _TRANSFORMERS_AVAILABLE])
_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE

_EXTRAS_AVAILABLE = {
Expand Down
2 changes: 1 addition & 1 deletion requirements/datatype_audio.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torchaudio
soundfile>=0.10.2
librosa>=0.8.1
transformers>=4.5
datasets>=1.8

0 comments on commit 4dd2830

Please sign in to comment.