Skip to content

Commit

Permalink
Support compression level in i/o dispatcher backend (#3662)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #3662

Differential Revision: D50367721

fbshipit-source-id: fbb6d62a6bdf6a55c98a8e94d08e344b8377f1dc
  • Loading branch information
hwangjeff authored and facebook-github-bot committed Oct 18, 2023
1 parent 671261c commit 025314e
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 18 deletions.
2 changes: 2 additions & 0 deletions src/torchaudio/_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import BinaryIO, Optional, Tuple, Union

from torch import Tensor
from torchaudio.io import CodecConfig

from .common import AudioMetaData

Expand Down Expand Up @@ -37,6 +38,7 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float, int]] = None,
) -> None:
raise NotImplementedError

Expand Down
9 changes: 9 additions & 0 deletions src/torchaudio/_backend/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def save_audio(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None:
ext = None
if hasattr(uri, "write"):
Expand All @@ -250,6 +251,7 @@ def save_audio(
format=_get_sample_format(src.dtype),
encoder=encoder,
encoder_format=enc_fmt,
codec_config=compression,
)
with s.open():
s.write_audio_chunk(0, src)
Expand Down Expand Up @@ -304,7 +306,13 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None:
if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
raise ValueError(
"FFmpeg backend expects non-`None` value for argument `compression` to be of ",
f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}",
)
save_audio(
uri,
src,
Expand All @@ -314,6 +322,7 @@ def save(
encoding,
bits_per_sample,
buffer_size,
compression,
)

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions src/torchaudio/_backend/soundfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import BinaryIO, Optional, Tuple, Union

import torch
from torchaudio.io import CodecConfig

from . import soundfile_backend
from .backend import Backend
Expand Down Expand Up @@ -35,7 +36,11 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float, int]] = None,
) -> None:
if compression:
raise ValueError("soundfile backend does not support argument `compression`.")

soundfile_backend.save(
uri, src, sample_rate, channels_first, format=format, encoding=encoding, bits_per_sample=bits_per_sample
)
Expand Down
8 changes: 7 additions & 1 deletion src/torchaudio/_backend/sox.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None:
if not isinstance(compression, (float, int, type(None))):
raise ValueError(
"SoX backend expects non-`None` value for argument `compression` to be of ",
f"type `float` or `int`, but received value of type {type(compression)}",
)
if hasattr(uri, "write"):
raise ValueError(
"SoX backend does not support writing to file-like objects. ",
Expand All @@ -68,7 +74,7 @@ def save(
src,
sample_rate,
channels_first,
None,
compression,
format,
encoding,
bits_per_sample,
Expand Down
9 changes: 8 additions & 1 deletion src/torchaudio/_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext
from torchaudio.io import CodecConfig

from . import soundfile_backend

Expand Down Expand Up @@ -229,6 +230,7 @@ def save(
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
backend: Optional[str] = None,
compression: Optional[Union[CodecConfig, float, int]] = None,
):
"""Save audio data to file.
Expand Down Expand Up @@ -283,8 +285,13 @@ def save(
.. seealso::
:ref:`backend`
compression (CodecConfig, float, int, or None, optional):
To fill in.
"""
backend = dispatcher(uri, format, backend)
return backend.save(uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size)
return backend.save(
uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size, compression
)

return save
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_save(self, available_backends, expected_backend):
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(filename, src, sample_rate, format=format)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096, None)

@parameterized.expand(
[
Expand All @@ -126,4 +126,4 @@ def test_save_fileobj(self, available_backends, expected_backend):
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(f, src, sample_rate, format=format, buffer_size=buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size, None)
53 changes: 47 additions & 6 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_save_func
from torchaudio.io import CodecConfig

from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import (
Expand All @@ -25,7 +26,7 @@
)


def _convert_audio_file(src_path, dst_path, muxer=None, encoder=None, sample_fmt=None):
def _convert_audio_file(src_path, dst_path, muxer=None, encoder=None, sample_fmt=None, compression=None):
command = ["ffmpeg", "-hide_banner", "-y", "-i", src_path, "-strict", "-2"]
if muxer:
command += ["-f", muxer]
Expand All @@ -45,6 +46,7 @@ def assert_save_consistency(
self,
format: str,
*,
compression: CodecConfig = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
Expand Down Expand Up @@ -104,14 +106,23 @@ def assert_save_consistency(
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
ext = format
self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample)
self._save(
tgt_path,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "fileobj":
ext = None
with open(tgt_path, "bw") as file_:
self._save(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand All @@ -123,6 +134,7 @@ def assert_save_consistency(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand Down Expand Up @@ -198,11 +210,27 @@ def test_save_wav_dtype(self, test_mode, params):
# NOTE: Supported sample formats: s16 s32 (24 bits)
# [8, 16, 24],
[16, 24],
[
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, test_mode, bits_per_sample):
def test_save_flac(self, test_mode, bits_per_sample, compression_level):
# -acodec flac -sample_fmt s16
# 24 bits needs to be mapped to s32
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode=test_mode)
codec_config = CodecConfig(
compression_level=compression_level,
)
self.assert_save_consistency(
"flac", compression=codec_config, bits_per_sample=bits_per_sample, test_mode=test_mode
)

# @nested_params(
# ["path", "fileobj", "bytesio"],
Expand All @@ -212,12 +240,25 @@ def test_save_flac(self, test_mode, bits_per_sample):
# self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)

@nested_params(
[
None,
-1,
0,
1,
2,
3,
5,
10,
],
["path", "fileobj", "bytesio"],
)
def test_save_vorbis(self, test_mode):
def test_save_vorbis(self, quality_level, test_mode):
# NOTE: ffmpeg doesn't recognize extension "vorbis", so we use "ogg"
# self.assert_save_consistency("vorbis", test_mode=test_mode)
self.assert_save_consistency("ogg", test_mode=test_mode)
codec_config = CodecConfig(
qscale=quality_level,
)
self.assert_save_consistency("ogg", compression=codec_config, test_mode=test_mode)

# @nested_params(
# ["path", "fileobj", "bytesio"],
Expand Down
63 changes: 55 additions & 8 deletions test/torchaudio_unittest/backend/dispatcher/sox/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def assert_save_consistency(
self,
format: str,
*,
compression: float = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
Expand Down Expand Up @@ -101,13 +102,16 @@ def assert_save_consistency(
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
self._save(
tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
)
elif test_mode == "fileobj":
with open(tgt_path, "bw") as file_:
self._save(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand All @@ -118,6 +122,7 @@ def assert_save_consistency(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand All @@ -134,7 +139,9 @@ def assert_save_consistency(

# 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file(src_path, sox_path, encoding=sox_encoding, bit_depth=bits_per_sample)
sox_utils.convert_audio_file(
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
Expand Down Expand Up @@ -175,15 +182,42 @@ def test_save_wav_dtype(self, params):

@nested_params(
[8, 16, 24],
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, bits_per_sample):
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode="path")
def test_save_flac(self, bits_per_sample, compression_level):
self.assert_save_consistency(
"flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode="path"
)

def test_save_htk(self):
self.assert_save_consistency("htk", test_mode="path", num_channels=1)

def test_save_vorbis(self):
self.assert_save_consistency("vorbis", test_mode="path")
@nested_params(
[
None,
-1,
0,
1,
2,
3,
3.6,
5,
10,
],
)
def test_save_vorbis(self, quality_level):
self.assert_save_consistency("vorbis", compression=quality_level, test_mode="path")

@nested_params(
[
Expand Down Expand Up @@ -254,9 +288,22 @@ def test_save_amb(self, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode="path")

@nested_params(
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
],
)
@skipIfNoSoxEncoder("amr-nb")
def test_save_amr_nb(self):
self.assert_save_consistency("amr-nb", num_channels=1, test_mode="path")
def test_save_amr_nb(self, bit_rate):
self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode="path")

def test_save_gsm(self):
self.assert_save_consistency("gsm", num_channels=1, test_mode="path")
Expand Down

0 comments on commit 025314e

Please sign in to comment.