Skip to content

Commit

Permalink
Remove file-like object support from sox_io backend
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Apr 5, 2023
1 parent 79b5305 commit 6e986be
Show file tree
Hide file tree
Showing 23 changed files with 81 additions and 1,740 deletions.
268 changes: 1 addition & 267 deletions test/torchaudio_unittest/backend/sox_io/info_test.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
import io
import itertools
import os
import tarfile
from contextlib import contextmanager

from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend
from torchaudio.utils.sox_utils import get_buffer_size, set_buffer_size
from torchaudio_unittest.backend.common import get_bits_per_sample, get_encoding
from torchaudio_unittest.backend.common import get_encoding
from torchaudio_unittest.common_utils import (
get_asset_path,
get_wav_data,
HttpServerMixin,
PytorchTestCase,
save_wav,
skipIfNoExec,
skipIfNoModule,
skipIfNoSox,
sox_utils,
TempDirMixin,
Expand All @@ -25,10 +17,6 @@
from .common import name_func


if _mod_utils.is_module_available("requests"):
import requests


@skipIfNoExec("sox")
@skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase):
Expand Down Expand Up @@ -338,260 +326,6 @@ def test_mp3(self):
assert sinfo.encoding == "MP3"


class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self.get_temp_path(f"test.{ext}")
bit_depth = sox_utils.get_bit_depth(dtype)
duration = num_frames / sample_rate
comment_file = self._gen_comment_file(comments) if comments else None

sox_utils.gen_audio_file(
path,
sample_rate,
num_channels=num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=bit_depth,
duration=duration,
comment_file=comment_file,
)
return path

def _gen_comment_file(self, comments):
comment_path = self.get_temp_path("comment.txt")
with open(comment_path, "w") as file_:
file_.writelines(comments)
return comment_path


class Unseekable:
def __init__(self, fileobj):
self.fileobj = fileobj

def read(self, n):
return self.fileobj.read(n)


@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
format_ = ext if ext in ["mp3"] else None
with open(path, "rb") as fileobj:
return sox_io_backend.info(fileobj, format_)

def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
format_ = ext if ext in ["mp3"] else None
with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read())
return sox_io_backend.info(fileobj, format_)

def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path)
archive_path = self.get_temp_path("archive.tar.gz")
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
format_ = ext if ext in ["mp3"] else None
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
return sox_io_backend.info(fileobj, format_)

@contextmanager
def _set_buffer_size(self, buffer_size):
try:
original_buffer_size = get_buffer_size()
set_buffer_size(buffer_size)
yield
finally:
set_buffer_size(original_buffer_size)

@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_fileobj(self, ext, dtype):
"""Querying audio via file object works"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)

bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)

assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)

@parameterized.expand(
[
("vorbis", "float32"),
]
)
def test_fileobj_large_header(self, ext, dtype):
"""
For audio file with header size exceeding default buffer size:
- Querying audio via file object without enlarging buffer size fails.
- Querying audio via file object after enlarging buffer size succeeds.
"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
comments = "metadata=" + " ".join(["value" for _ in range(1000)])

with self.assertRaises(RuntimeError):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)

with self._set_buffer_size(16384):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["vorbis"] else num_frames

assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)

@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_bytesio(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)

bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)

assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)

@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_bytesio_tiny(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 8000
num_frames = 4
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)

bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = {"vorbis": 0, "mp3": 1728}.get(ext, num_frames)

assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)

@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_tarfile(self, ext, dtype):
"""Querying compressed audio via file-like object works"""
sample_rate = 16000
num_frames = 3.0 * sample_rate
num_channels = 2
sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)

bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)

assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)


@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path)

url = self.get_url(audio_file)
format_ = ext if ext in ["mp3"] else None
with requests.get(url, stream=True) as resp:
return sox_io_backend.info(Unseekable(resp.raw), format=format_)

@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_requests(self, ext, dtype):
"""Querying compressed audio via requests works"""
sample_rate = 16000
num_frames = 3.0 * sample_rate
num_channels = 2
sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames)

bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)

assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)


@skipIfNoSox
class TestInfoNoSuchFile(PytorchTestCase):
def test_info_fail(self):
Expand Down
Loading

0 comments on commit 6e986be

Please sign in to comment.