Skip to content

Commit

Permalink
Use FFmpeg-based I/O as fallback in sox_io backend (pytorch#2419)
Browse files Browse the repository at this point in the history
Summary:
This commit adds a fallback mechanism to the `info` and `load` functions of the sox_io backend.
If torchaudio is compiled to use FFmpeg and the runtime dependencies are properly loaded,
then when `info` or `load` fails, it falls back to the FFmpeg-based implementation.

Depends on pytorch#2416, pytorch#2417, pytorch#2418

Pull Request resolved: pytorch#2419

Differential Revision: D36740306

Pulled By: mthrok

fbshipit-source-id: b933f5a55ec5f5bb8d6ec4abbf1ad796839258d0
  • Loading branch information
mthrok authored and facebook-github-bot committed May 29, 2022
1 parent bb77cbe commit e010fdd
Show file tree
Hide file tree
Showing 19 changed files with 306 additions and 63 deletions.
16 changes: 12 additions & 4 deletions test/torchaudio_unittest/backend/sox_io/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def test_mp3(self):
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 81216
assert sinfo.num_frames == 0
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert sinfo.encoding == "MP3"
Expand Down Expand Up @@ -355,6 +355,14 @@ def _gen_comment_file(self, comments):
return comment_path


class Unseekable:
    """Wrapper around a file-like object that exposes only ``read``.

    No other attribute of the wrapped object (such as ``seek``) is forwarded.
    """

    def __init__(self, fileobj):
        # Keep a handle to the wrapped object; reads are delegated to it.
        self.fileobj = fileobj

    def read(self, n):
        """Return whatever ``self.fileobj.read(n)`` returns."""
        chunk = self.fileobj.read(n)
        return chunk


@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase):
Expand Down Expand Up @@ -435,7 +443,7 @@ def test_fileobj_large_header(self, ext, dtype):
num_channels = 2
comments = "metadata=" + " ".join(["value" for _ in range(1000)])

with self.assertRaisesRegex(RuntimeError, "^Error loading audio file:"):
with self.assertRaises(RuntimeError):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)

with self._set_buffer_size(16384):
Expand Down Expand Up @@ -545,7 +553,7 @@ def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
url = self.get_url(audio_file)
format_ = ext if ext in ["mp3"] else None
with requests.get(url, stream=True) as resp:
return sox_io_backend.info(resp.raw, format=format_)
return sox_io_backend.info(Unseekable(resp.raw), format=format_)

@parameterized.expand(
[
Expand Down Expand Up @@ -583,5 +591,5 @@ def test_info_fail(self):
When attempted to get info on a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)):
with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.info(path)
33 changes: 22 additions & 11 deletions test/torchaudio_unittest/backend/sox_io/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import (
ffmpeg_utils,
get_asset_path,
get_wav_data,
HttpServerMixin,
Expand Down Expand Up @@ -81,7 +82,10 @@ def assert_format(
)
# 2. Convert to wav with sox
wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav
sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth)
if format == "mp3":
ffmpeg_utils.convert_to_wav(path, ref_path)
else:
sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth)
# 3. Load the given format with torchaudio
data, sr = sox_io_backend.load(path, normalize=normalize)
# 4. Load wav with scipy
Expand Down Expand Up @@ -377,14 +381,12 @@ def test_mp3(self):
# Test helper that wraps a file object and hands data back in small pieces.
# NOTE(review): this is a diff view with the +/- markers lost; two `read` definitions appear
# below. Presumably the buffered `read(self, n)` is the pre-commit version and
# `read(self, _)` plus `seek` are the post-commit version -- confirm against the real hunk.
class CloggedFileObj:
def __init__(self, fileobj):
self.fileobj = fileobj
self.buffer = b""

# Buffered variant: refills the buffer from the wrapped object only when it is empty,
# then returns the first 2 bytes of the buffer and keeps the rest for later calls.
def read(self, n):
if not self.buffer:
self.buffer += self.fileobj.read(n)
ret = self.buffer[:2]
self.buffer = self.buffer[2:]
return ret
# Unbuffered variant: ignores the requested size and always asks the wrapped object for 2 bytes.
def read(self, _):
return self.fileobj.read(2)

# Forwards the seek request unchanged to the wrapped object.
def seek(self, offset, whence):
return self.fileobj.seek(offset, whence)


@skipIfNoSox
Expand Down Expand Up @@ -557,6 +559,14 @@ def test_tarfile(self, ext, kwargs):
self.assertEqual(expected, found)


class Unseekable:
    """File-like adapter that offers ``read`` and nothing else.

    The wrapped object's remaining methods (``seek`` included) are not exposed.
    """

    def __init__(self, fileobj):
        # The underlying object that actually provides the data.
        self.fileobj = fileobj

    def read(self, n):
        """Delegate to ``self.fileobj.read(n)`` and return its result."""
        data = self.fileobj.read(n)
        return data


@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
Expand Down Expand Up @@ -587,10 +597,11 @@ def test_requests(self, ext, kwargs):

url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, format=format_)
found, sr = sox_io_backend.load(Unseekable(resp.raw), format=format_)

assert sr == sample_rate
self.assertEqual(expected, found)
if ext != "mp3":
self.assertEqual(expected, found)

@parameterized.expand(
list(
Expand Down Expand Up @@ -627,5 +638,5 @@ def test_load_fail(self):
When attempted to load a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)):
with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.load(path)
13 changes: 10 additions & 3 deletions test/torchaudio_unittest/backend/sox_io/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import (
ffmpeg_utils,
get_wav_data,
load_wav,
nested_params,
Expand Down Expand Up @@ -130,7 +131,10 @@ def assert_save_consistency(
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
if format == "mp3":
ffmpeg_utils.convert_to_wav(tgt_path, tst_path)
else:
sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]

Expand All @@ -140,7 +144,10 @@ def assert_save_consistency(
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
if format == "mp3":
ffmpeg_utils.convert_to_wav(sox_path, ref_path)
else:
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]

Expand Down Expand Up @@ -437,5 +444,5 @@ def test_save_fail(self):
When attempted to save into a non-existing dir, error message must contain the file path.
"""
path = os.path.join("non_existing_directory", "foo.wav")
with self.assertRaisesRegex(RuntimeError, "^Error saving audio file: failed to open file {0}$".format(path)):
with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.save(path, torch.zeros(1, 1), 8000)
11 changes: 0 additions & 11 deletions test/torchaudio_unittest/backend/sox_io/smoke_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import io
import itertools
import unittest

from parameterized import parameterized
from torchaudio._internal.module_utils import is_sox_available
from torchaudio.backend import sox_io_backend
from torchaudio.utils import sox_utils
from torchaudio_unittest.common_utils import (
get_wav_data,
skipIfNoSox,
Expand All @@ -16,12 +13,6 @@
from .common import name_func


skipIfNoMP3 = unittest.skipIf(
not is_sox_available() or "mp3" not in sox_utils.list_read_formats() or "mp3" not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3',
)


@skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format
Expand Down Expand Up @@ -73,7 +64,6 @@ def test_wav(self, dtype, sample_rate, num_channels):
)
)
)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
Expand Down Expand Up @@ -159,7 +149,6 @@ def test_wav(self, dtype, sample_rate, num_channels):
)
)
)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
Expand Down
10 changes: 10 additions & 0 deletions test/torchaudio_unittest/common_utils/ffmpeg_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import subprocess
import sys


def convert_to_wav(src_path, dst_path):
    """Convert an audio file with the ``ffmpeg`` command.

    The command is echoed to stderr before it is executed.

    Args:
        src_path (str or path-like): Input audio file.
        dst_path (str or path-like): Output file, written with the ``pcm_f32le`` audio codec.
            ``-y`` is passed, so ffmpeg is told to overwrite an existing file.

    Raises:
        subprocess.CalledProcessError: If ``ffmpeg`` exits with a non-zero status
            (``check=True``).
    """
    # TODO: parameterize codec
    # Stringify the paths: `" ".join(...)` below raises TypeError on non-str items
    # such as `pathlib.Path`.
    command = ["ffmpeg", "-y", "-i", str(src_path), "-c:a", "pcm_f32le", str(dst_path)]
    print(" ".join(command), file=sys.stderr)
    subprocess.run(command, check=True)
65 changes: 59 additions & 6 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,52 @@
from .common import AudioMetaData


# Note: must comply with TorchScript syntax -- annotations are required and f-strings are not allowed
def _alt_info(filepath: str, format: Optional[str]) -> AudioMetaData:
    """Build ``AudioMetaData`` from the result of the ``ffmpeg_get_audio_info`` op."""
    sinfo = torch.ops.torchaudio.ffmpeg_get_audio_info(filepath, format)
    return AudioMetaData(*sinfo)


def _alt_info_fileobj(fileobj, format: Optional[str]) -> AudioMetaData:
    """Build ``AudioMetaData`` from ``_torchaudio_ffmpeg.get_audio_info_fileobj`` for a file-like object."""
    sinfo = torchaudio._torchaudio_ffmpeg.get_audio_info_fileobj(fileobj, format)
    return AudioMetaData(*sinfo)


# Note: must comply with TorchScript syntax -- annotations are required and f-strings are not allowed
def _fail_info(filepath: str, format: Optional[str]) -> AudioMetaData:
    """Always raise ``RuntimeError`` naming ``filepath``; ``format`` is ignored."""
    message = "Failed to fetch metadata from {}".format(filepath)
    raise RuntimeError(message)


def _fail_info_fileobj(fileobj, format: Optional[str]) -> AudioMetaData:
    """Always raise ``RuntimeError`` naming ``fileobj``; ``format`` is ignored."""
    message = "Failed to fetch metadata from {}".format(fileobj)
    raise RuntimeError(message)


# Note: need to comply TorchScript syntax -- need annotation and no f-string
def _fail_load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
raise RuntimeError("Failed to load audio from {}".format(filepath))


def _fail_load_fileobj(fileobj, *args, **kwargs):
raise RuntimeError(f"Failed to load audio from {fileobj}")


# Choose the fallback implementations once, at import time: the FFmpeg-based ones when
# the FFmpeg extension was initialized, otherwise stubs that raise RuntimeError.
if torchaudio._extension._FFMPEG_INITIALIZED:
    _fallback_info = _alt_info
    _fallback_info_fileobj = _alt_info_fileobj
    _fallback_load = torch.ops.torchaudio.ffmpeg_load_audio
    _fallback_load_fileobj = torchaudio._torchaudio_ffmpeg.load_audio_fileobj
else:
    _fallback_info = _fail_info
    _fallback_info_fileobj = _fail_info_fileobj
    _fallback_load = _fail_load
    # Fixed typo: this was `_fallback_load_filebj`, which left `_fallback_load_fileobj`
    # undefined in this branch, so `load` on a file object hit NameError instead of
    # the intended RuntimeError.
    _fallback_load_fileobj = _fail_load_fileobj


@_mod_utils.requires_sox()
def info(
filepath: str,
Expand Down Expand Up @@ -46,11 +92,14 @@ def info(
if not torch.jit.is_scripting():
if hasattr(filepath, "read"):
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo)
if sinfo is not None:
return AudioMetaData(*sinfo)
return _fallback_info_fileobj(filepath, format)
filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
assert sinfo is not None # for TorchScript compatibility
return AudioMetaData(*sinfo)
if sinfo is not None:
return AudioMetaData(*sinfo)
return _fallback_info(filepath, format)


@_mod_utils.requires_sox()
Expand Down Expand Up @@ -145,15 +194,19 @@ def load(
"""
if not torch.jit.is_scripting():
if hasattr(filepath, "read"):
return torchaudio._torchaudio.load_audio_fileobj(
ret = torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
if ret is not None:
return ret
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath = os.fspath(filepath)
ret = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
assert ret is not None # for TorchScript compatibility
return ret
if ret is not None:
return ret
return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format)


@_mod_utils.requires_sox()
Expand Down
8 changes: 8 additions & 0 deletions torchaudio/csrc/ffmpeg/pybind/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@ namespace ffmpeg {
namespace {

PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
m.def(
"load_audio_fileobj",
&torchaudio::ffmpeg::load_audio_fileobj,
"Load audio from file object.");
m.def(
"get_audio_info_fileobj",
&torchaudio::ffmpeg::get_audio_info_fileobj,
"Get metadata of audio in file object.");
py::class_<StreamReaderFileObj, c10::intrusive_ptr<StreamReaderFileObj>>(
m, "StreamReaderFileObj")
.def(py::init<
Expand Down
32 changes: 32 additions & 0 deletions torchaudio/csrc/ffmpeg/pybind/stream_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,37 @@ StreamReaderFileObj::StreamReaderFileObj(
option.value_or(OptionDict{}),
pAVIO)) {}

// Load audio from a Python file-like object.
//
// Wraps `fileobj` in a `FileObj`, builds an input format context on top of its
// `pAVIO` member, and forwards that context with the remaining arguments to `load_audio`.
// `fileobj`'s `__str__()` result is passed as the first argument of
// `get_input_format_context` (presumably used as the source name -- confirm).
// The Python caller passes its `normalize` flag in the `convert` position.
std::tuple<c10::optional<torch::Tensor>, int64_t> load_audio_fileobj(
py::object fileobj,
const c10::optional<int64_t>& frame_offset,
const c10::optional<int64_t>& num_frames,
bool convert,
bool channels_first,
const c10::optional<std::string>& format) {
// NOTE(review): 4086 is presumably a buffer size and may have been meant to be 4096 -- confirm.
FileObj f{fileobj, 4086};
return load_audio(
get_input_format_context(
static_cast<std::string>(py::str(fileobj.attr("__str__")())),
format,
{},
f.pAVIO),
frame_offset,
num_frames,
convert,
channels_first,
format);
}

// Fetch audio metadata from a Python file-like object.
//
// Wraps `fileobj` in a `FileObj`, builds an input format context on top of its
// `pAVIO` member, and returns the result of `get_audio_info` on that context.
// `fileobj`'s `__str__()` result is passed as the first argument of
// `get_input_format_context` (presumably used as the source name -- confirm).
MetaDataTuple get_audio_info_fileobj(
py::object fileobj,
c10::optional<std::string> format) {
// NOTE(review): 4086 is presumably a buffer size and may have been meant to be 4096 -- confirm.
FileObj f{fileobj, 4086};
return get_audio_info(get_input_format_context(
static_cast<std::string>(py::str(fileobj.attr("__str__")())),
format,
{},
f.pAVIO));
}

} // namespace ffmpeg
} // namespace torchaudio
15 changes: 15 additions & 0 deletions torchaudio/csrc/ffmpeg/pybind/stream_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,20 @@ class StreamReaderFileObj : protected FileObj, public StreamReaderBinding {
int64_t buffer_size);
};

/// Load audio from a Python file-like object.
/// Exposed to Python as `load_audio_fileobj` in the `_torchaudio_ffmpeg` module.
/// The Python caller passes its `normalize` flag in the `convert` position.
/// Returns an optional tensor paired with an integer
/// (presumably the waveform and its sample rate -- confirm).
std::tuple<c10::optional<torch::Tensor>, int64_t> load_audio_fileobj(
py::object fileobj,
const c10::optional<int64_t>& frame_offset,
const c10::optional<int64_t>& num_frames,
bool convert,
bool channels_first,
const c10::optional<std::string>& format);

// Tuple returned by `get_audio_info_fileobj`; the Python side unpacks it into `AudioMetaData`.
// NOTE(review): the fields are presumably sample rate, number of frames, number of channels,
// bits per sample and encoding, in that order -- confirm against `AudioMetaData`.
using MetaDataTuple =
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>;

/// Fetch metadata of the audio in a Python file-like object.
/// Exposed to Python as `get_audio_info_fileobj` in the `_torchaudio_ffmpeg` module.
MetaDataTuple get_audio_info_fileobj(
py::object fileobj,
c10::optional<std::string> format);

} // namespace ffmpeg
} // namespace torchaudio
Loading

0 comments on commit e010fdd

Please sign in to comment.