Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

replace soundfile with librosa #726

Merged
merged 10 commits into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for the `ObjectDetector` with FiftyOne ([#727](https://github.com/PyTorchLightning/lightning-flash/pull/727))

- Added support for MP3 files to the `SpeechRecognition` task with librosa ([#726](https://github.com/PyTorchLightning/lightning-flash/pull/726))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
56 changes: 38 additions & 18 deletions flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires_extras

# Heavy audio dependencies are imported lazily, only when the audio extras are
# installed. librosa replaces soundfile here, adding MP3 decoding support
# (see CHANGELOG entry for PR #726).
if _AUDIO_AVAILABLE:
    import librosa
    from datasets import Dataset as HFDataset
    from datasets import load_dataset
    from transformers import Wav2Vec2CTCTokenizer
Expand All @@ -44,11 +44,16 @@


class SpeechRecognitionDeserializer(Deserializer):
def __init__(self, sampling_rate: int):
super().__init__()

self.sampling_rate = sampling_rate

def deserialize(self, sample: Any) -> Dict:
encoded_with_padding = (sample + "===").encode("ascii")
audio = base64.b64decode(encoded_with_padding)
buffer = io.BytesIO(audio)
data, sampling_rate = sf.read(buffer)
data, sampling_rate = librosa.load(buffer, sr=self.sampling_rate)
return {
DefaultDataKeys.INPUT: data,
DefaultDataKeys.METADATA: {"sampling_rate": sampling_rate},
Expand All @@ -61,24 +66,26 @@ def example_input(self) -> str:


class BaseSpeechRecognition:
    """Mixin providing the audio-file loading shared by the speech recognition data sources."""

    @staticmethod
    def _load_sample(sample: Dict[str, Any], sampling_rate: int) -> Any:
        """Decode the audio file referenced by ``sample`` into a waveform.

        Args:
            sample: Mapping with the file path under ``DefaultDataKeys.INPUT`` and,
                optionally, a ``"root"`` directory under ``DefaultDataKeys.METADATA``.
            sampling_rate: Target rate (Hz) passed to ``librosa.load`` as ``sr``.

        Returns:
            The same mapping, with the decoded waveform under ``INPUT`` and the
            effective sampling rate recorded under ``METADATA``.
        """
        path = sample[DefaultDataKeys.INPUT]
        # Relative paths are resolved against the optional dataset root directory.
        if (
            not os.path.isabs(path)
            and DefaultDataKeys.METADATA in sample
            and "root" in sample[DefaultDataKeys.METADATA]
        ):
            path = os.path.join(sample[DefaultDataKeys.METADATA]["root"], path)
        speech_array, sampling_rate = librosa.load(path, sr=sampling_rate)
        sample[DefaultDataKeys.INPUT] = speech_array
        sample[DefaultDataKeys.METADATA] = {"sampling_rate": sampling_rate}
        return sample


class SpeechRecognitionFileDataSource(DataSource, BaseSpeechRecognition):
    """Data source loading speech samples listed in a manifest file (e.g. CSV or JSON)."""

    def __init__(self, sampling_rate: int, filetype: Optional[str] = None):
        super().__init__()
        self.filetype = filetype
        # Rate (Hz) that every loaded clip is resampled to.
        self.sampling_rate = sampling_rate

def load_data(
self,
Expand Down Expand Up @@ -107,32 +114,44 @@ def load_data(
]

def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
return self._load_sample(sample)
return self._load_sample(sample, self.sampling_rate)


class SpeechRecognitionCSVDataSource(SpeechRecognitionFileDataSource):
    """File data source reading speech samples from a CSV manifest."""

    def __init__(self, sampling_rate: int):
        super().__init__(sampling_rate, filetype="csv")


class SpeechRecognitionJSONDataSource(SpeechRecognitionFileDataSource):
    """File data source reading speech samples from a JSON manifest."""

    def __init__(self, sampling_rate: int):
        super().__init__(sampling_rate, filetype="json")


class SpeechRecognitionDatasetDataSource(DatasetDataSource, BaseSpeechRecognition):
    """Data source wrapping an existing dataset, including HuggingFace ``Dataset`` objects."""

    def __init__(self, sampling_rate: int):
        super().__init__()
        # Rate (Hz) that every loaded clip is resampled to.
        self.sampling_rate = sampling_rate

    def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]]]:
        # HuggingFace datasets are flattened to (file, text) pairs before delegating.
        if isinstance(data, HFDataset):
            data = list(zip(data["file"], data["text"]))
        return super().load_data(data, dataset)

    def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
        # Inputs may already be decoded arrays; only path-like inputs need loading.
        if isinstance(sample[DefaultDataKeys.INPUT], (str, Path)):
            sample = self._load_sample(sample, self.sampling_rate)
        return sample


class SpeechRecognitionPathsDataSource(PathsDataSource, BaseSpeechRecognition):
    """Data source loading audio files directly from file or folder paths."""

    def __init__(self, sampling_rate: int):
        # "mp3" is supported via librosa (added in PR #726, which replaced soundfile).
        super().__init__(("wav", "ogg", "flac", "mat", "mp3"))
        self.sampling_rate = sampling_rate

    def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
        """Decode the referenced audio file, resampled to the configured rate."""
        return self._load_sample(sample, self.sampling_rate)


class SpeechRecognitionPreprocess(Preprocess):
Expand All @@ -143,20 +162,21 @@ def __init__(
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
sampling_rate: int = 16000,
):
super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.CSV: SpeechRecognitionCSVDataSource(),
DefaultDataSources.JSON: SpeechRecognitionJSONDataSource(),
DefaultDataSources.FILES: SpeechRecognitionPathsDataSource(),
DefaultDataSources.DATASETS: SpeechRecognitionDatasetDataSource(),
DefaultDataSources.CSV: SpeechRecognitionCSVDataSource(sampling_rate),
DefaultDataSources.JSON: SpeechRecognitionJSONDataSource(sampling_rate),
DefaultDataSources.FILES: SpeechRecognitionPathsDataSource(sampling_rate),
DefaultDataSources.DATASETS: SpeechRecognitionDatasetDataSource(sampling_rate),
},
default_data_source=DefaultDataSources.FILES,
deserializer=SpeechRecognitionDeserializer(),
deserializer=SpeechRecognitionDeserializer(sampling_rate),
)

def get_state_dict(self) -> Dict[str, Any]:
Expand Down
4 changes: 2 additions & 2 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _compare_version(package: str, op, version) -> bool:
# Availability flags for optional third-party packages, probed once at import time.
_PIL_AVAILABLE = _module_available("PIL")
_OPEN3D_AVAILABLE = _module_available("open3d")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_LIBROSA_AVAILABLE = _module_available("librosa")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
Expand Down Expand Up @@ -148,7 +148,7 @@ class Image(metaclass=MetaImage):
)
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE
# Audio tasks now depend on librosa for decoding (replacing soundfile, PR #726).
_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _LIBROSA_AVAILABLE, _TRANSFORMERS_AVAILABLE])
_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE

_EXTRAS_AVAILABLE = {
Expand Down
2 changes: 1 addition & 1 deletion requirements/datatype_audio.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torchaudio
librosa>=0.8.1
transformers>=4.5
datasets>=1.8