Skip to content

Commit

Permalink
Modifying Pitchshift for faster resampling (#2441)
Browse files Browse the repository at this point in the history
Summary:
Split the existing PitchShift implementation into multiple helper functions so that the resampling kernel can be cached, which speeds up the overall process. Addresses #2359.
Existing unit tests pass.

Edit: functional and transforms unit tests pass. Adopted lazy initialization to avoid breaking backward compatibility.

Pull Request resolved: #2441

Reviewed By: carolineechen

Differential Revision: D36905582

Pulled By: skim0514

fbshipit-source-id: 6780db3ac8a29d59017a6abe7e82ce1fd17aaac2
  • Loading branch information
Sean Kim authored and Caroline Chen committed Jun 10, 2022
1 parent 315068f commit 633c2ab
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ def test_PitchShift(self):
sample_rate = 8000
n_steps = 4
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.PitchShift(sample_rate=sample_rate, n_steps=n_steps), waveform)
pitch_shift = T.PitchShift(sample_rate=sample_rate, n_steps=n_steps)
# dry-run for initializing parameters
pitch_shift(waveform)
self._assert_consistency(pitch_shift, waveform)

def test_PSD(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
Expand Down
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/transforms/transforms_test_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torchaudio.transforms as T
from torchaudio.functional.functional import _get_sinc_resample_kernel
from parameterized import param, parameterized
from torchaudio_unittest.common_utils import (
get_spectrogram,
Expand Down Expand Up @@ -147,3 +148,21 @@ def test_mvdr(self, dtype):
mask_n = torch.rand(specgram.shape[-2:])
specgram_enhanced = transform(specgram, mask_s, mask_n)
assert specgram_enhanced.dtype == dtype

def test_pitch_shift_resample_kernel(self):
    """Check that PitchShift's cached kernel equals the helper-generated one.

    The kernel held by the transform after its first call is compared with
    the output of ``_get_sinc_resample_kernel`` built for the same
    frequencies, device and dtype, so any numerical difference introduced
    by a dtype conversion would make the comparison fail.
    """
    sample_rate = 8000
    transform = T.PitchShift(sample_rate=sample_rate, n_steps=4)
    transform.to(self.dtype).to(self.device)

    # The kernel is created lazily, so run the transform once to build it.
    dummy_input = torch.randn(2, 8000, dtype=self.dtype, device=self.device)
    transform(dummy_input)

    reference_kernel, _ = _get_sinc_resample_kernel(
        transform.orig_freq,
        sample_rate,
        transform.gcd,
        device=self.device,
        dtype=self.dtype,
    )
    self.assertEqual(transform.kernel, reference_kernel)
62 changes: 56 additions & 6 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import math
import warnings
from collections.abc import Sequence
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, List

import torch
import torchaudio
Expand Down Expand Up @@ -1389,10 +1389,10 @@ def _get_sinc_resample_kernel(
orig_freq: int,
new_freq: int,
gcd: int,
lowpass_filter_width: int,
rolloff: float,
resampling_method: str,
beta: Optional[float],
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
beta: Optional[float] = None,
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None,
):
Expand Down Expand Up @@ -1635,6 +1635,39 @@ def pitch_shift(
Returns:
Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
"""
waveform_stretch = _stretch_waveform(
waveform,
n_steps,
bins_per_octave,
n_fft,
win_length,
hop_length,
window,
)
rate = 2.0 ** (-float(n_steps) / bins_per_octave)
waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)

return _fix_waveform_shape(waveform_shift, waveform.size())


def _stretch_waveform(
waveform: Tensor,
n_steps: int,
bins_per_octave: int = 12,
n_fft: int = 512,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window: Optional[Tensor] = None,
) -> Tensor:
"""
Pitch shift helper function to preprocess and stretch waveform before resampling step.
Args:
See pitch_shift arg descriptions.
Returns:
Tensor: The preprocessed waveform stretched prior to resampling.
"""
if hop_length is None:
hop_length = n_fft // 4
if win_length is None:
Expand Down Expand Up @@ -1666,7 +1699,24 @@ def pitch_shift(
waveform_stretch = torch.istft(
spec_stretch, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=len_stretch
)
waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
return waveform_stretch


def _fix_waveform_shape(
waveform_shift: Tensor,
shape: List[int],
) -> Tensor:
"""
PitchShift helper function to process after resampling step to fix the shape back.
Args:
waveform_shift(Tensor): The waveform after stretch and resample
shape (List[int]): The shape of initial waveform
Returns:
Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
"""
ori_len = shape[-1]
shift_len = waveform_shift.size()[-1]
if shift_len > ori_len:
waveform_shift = waveform_shift[..., :ori_len]
Expand Down
54 changes: 50 additions & 4 deletions torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@

import torch
from torch import Tensor
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedParameter

from torchaudio import functional as F
from torchaudio.functional.functional import (
_apply_sinc_resample_kernel,
_get_sinc_resample_kernel,
_stretch_waveform,
_fix_waveform_shape,
)

__all__ = []
Expand Down Expand Up @@ -1511,7 +1516,7 @@ def forward(self, waveform: Tensor) -> Tensor:
)


class PitchShift(torch.nn.Module):
class PitchShift(LazyModuleMixin, torch.nn.Module):
r"""Shift the pitch of a waveform by ``n_steps`` steps.
.. devices:: CPU CUDA
Expand All @@ -1537,6 +1542,9 @@ class PitchShift(torch.nn.Module):
"""
__constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"]

kernel: UninitializedParameter
width: int

def __init__(
self,
sample_rate: int,
Expand All @@ -1548,7 +1556,7 @@ def __init__(
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None,
) -> None:
super(PitchShift, self).__init__()
super().__init__()
self.n_steps = n_steps
self.bins_per_octave = bins_per_octave
self.sample_rate = sample_rate
Expand All @@ -1557,6 +1565,27 @@ def __init__(
self.hop_length = hop_length if hop_length is not None else self.win_length // 4
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer("window", window)
rate = 2.0 ** (-float(n_steps) / bins_per_octave)
self.orig_freq = int(sample_rate / rate)
self.gcd = math.gcd(int(self.orig_freq), int(sample_rate))

if self.orig_freq != sample_rate:
self.width = -1
self.kernel = UninitializedParameter(device=None, dtype=None)

def initialize_parameters(self, input):
    """Build and store the sinc resampling kernel if it is still missing.

    Nothing happens when the module has no uninitialized parameters or when
    ``orig_freq`` equals ``sample_rate``. Otherwise the kernel is generated
    by ``_get_sinc_resample_kernel`` with the dtype and device of ``input``,
    the returned width is stored in ``self.width``, and ``self.kernel`` is
    materialized to the kernel's shape and filled with its values.

    Args:
        input (Tensor): Tensor whose ``dtype`` and ``device`` are used to
            generate the kernel.
    """
    # NOTE(review): presumably invoked by LazyModuleMixin before the first
    # forward call -- confirm against the torch.nn.modules.lazy docs.
    needs_kernel = self.has_uninitialized_params() and self.orig_freq != self.sample_rate
    if not needs_kernel:
        return

    # Kernel creation and the in-place copy both run with autograd disabled.
    with torch.no_grad():
        resample_kernel, kernel_width = _get_sinc_resample_kernel(
            self.orig_freq,
            self.sample_rate,
            self.gcd,
            dtype=input.dtype,
            device=input.device,
        )
        self.width = kernel_width
        self.kernel.materialize(resample_kernel.shape)
        self.kernel.copy_(resample_kernel)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand All @@ -1566,10 +1595,10 @@ def forward(self, waveform: Tensor) -> Tensor:
Returns:
Tensor: The pitch-shifted audio of shape `(..., time)`.
"""
shape = waveform.size()

return F.pitch_shift(
waveform_stretch = _stretch_waveform(
waveform,
self.sample_rate,
self.n_steps,
self.bins_per_octave,
self.n_fft,
Expand All @@ -1578,6 +1607,23 @@ def forward(self, waveform: Tensor) -> Tensor:
self.window,
)

if self.orig_freq != self.sample_rate:
waveform_shift = _apply_sinc_resample_kernel(
waveform_stretch,
self.orig_freq,
self.sample_rate,
self.gcd,
self.kernel,
self.width,
)
else:
waveform_shift = waveform_stretch

return _fix_waveform_shape(
waveform_shift,
shape,
)


class RNNTLoss(torch.nn.Module):
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
Expand Down

0 comments on commit 633c2ab

Please sign in to comment.